
sqlglot.helper

  1from __future__ import annotations
  2
  3import datetime
  4import inspect
  5import logging
  6import re
  7import sys
  8import typing as t
  9from collections.abc import Collection, Set
 10from contextlib import contextmanager
 11from copy import copy
 12from enum import Enum
 13from itertools import count
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot import exp
 17    from sqlglot._typing import A, E, T
 18    from sqlglot.dialects.dialect import DialectType
 19    from sqlglot.expressions import Expression
 20
 21
 22CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
 23PYTHON_VERSION = sys.version_info[:2]
 24logger = logging.getLogger("sqlglot")
 25
 26
 27class AutoName(Enum):
 28    """
 29    This is used for creating Enum classes where `auto()` is the string form
 30    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
 31
 32    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
 33    """
 34
 35    def _generate_next_value_(name, _start, _count, _last_values):
 36        return name
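# Usage sketch (editor's addition, not part of the module source). A hypothetical enum
# whose member values equal their names:
#
#     >>> from enum import auto
#     >>> class Keyword(AutoName):
#     ...     SELECT = auto()
#     >>> Keyword.SELECT.value
#     'SELECT'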
 37
 38
 39class classproperty(property):
 40    """
 41    Similar to a normal property but works for class methods
 42    """
 43
 44    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
 45        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
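# Usage sketch (editor's addition, not part of the module source). The property is
# readable directly on the class, not just on instances:
#
#     >>> class Config:
#     ...     _name = "default"
#     ...     @classproperty
#     ...     def name(cls):
#     ...         return cls._name
#     >>> Config.name
#     'default'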
 46
 47
 48def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 49    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 50    try:
 51        return seq[index]
 52    except IndexError:
 53        return None
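# Usage sketch (editor's addition, not part of the module source):
#
#     >>> seq_get(["a", "b", "c"], 1)
#     'b'
#     >>> seq_get(["a", "b", "c"], 5) is None
#     True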
 54
 55
 56@t.overload
 57def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
 58
 59
 60@t.overload
 61def ensure_list(value: None) -> t.List: ...
 62
 63
 64@t.overload
 65def ensure_list(value: T) -> t.List[T]: ...
 66
 67
 68def ensure_list(value):
 69    """
 70    Ensures that a value is a list, otherwise casts or wraps it into one.
 71
 72    Args:
 73        value: The value of interest.
 74
 75    Returns:
 76        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 77    """
 78    if value is None:
 79        return []
 80    if isinstance(value, (list, tuple)):
 81        return list(value)
 82
 83    return [value]
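# Usage sketch (editor's addition, not part of the module source):
#
#     >>> ensure_list((1, 2))
#     [1, 2]
#     >>> ensure_list(1)
#     [1]
#     >>> ensure_list(None)
#     []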
 84
 85
 86@t.overload
 87def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
 88
 89
 90@t.overload
 91def ensure_collection(value: T) -> t.Collection[T]: ...
 92
 93
 94def ensure_collection(value):
 95    """
 96    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 97
 98    Args:
 99        value: The value of interest.
100
101    Returns:
102        The value if it's a collection, or else the value wrapped in a list.
103    """
104    if value is None:
105        return []
106    return (
107        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
108    )
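# Usage sketch (editor's addition, not part of the module source). Strings are wrapped
# rather than treated as collections of characters:
#
#     >>> ensure_collection("foo")
#     ['foo']
#     >>> ensure_collection([1, 2])
#     [1, 2]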
109
110
111def csv(*args: str, sep: str = ", ") -> str:
112    """
113    Formats any number of string arguments as CSV.
114
115    Args:
116        args: The string arguments to format.
117        sep: The argument separator.
118
119    Returns:
120        The arguments formatted as a CSV string.
121    """
122    return sep.join(arg for arg in args if arg)
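# Usage sketch (editor's addition, not part of the module source). Empty strings are skipped:
#
#     >>> csv("a", "", "b")
#     'a, b'
#     >>> csv("a", "b", sep=" AND ")
#     'a AND b'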
123
124
125def subclasses(
126    module_name: str,
127    classes: t.Type | t.Tuple[t.Type, ...],
128    exclude: t.Type | t.Tuple[t.Type, ...] = (),
129) -> t.List[t.Type]:
130    """
131    Returns all subclasses for a collection of classes, possibly excluding some of them.
132
133    Args:
134        module_name: The name of the module to search for subclasses in.
135        classes: Class(es) we want to find the subclasses of.
136        exclude: Class(es) we want to exclude from the returned list.
137
138    Returns:
139        The target subclasses.
140    """
141    return [
142        obj
143        for _, obj in inspect.getmembers(
144            sys.modules[module_name],
145            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
146        )
147    ]
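# Usage sketch (editor's addition, not part of the module source): collecting every
# Expression subclass defined in sqlglot.expressions.
#
#     from sqlglot import exp
#
#     expression_types = subclasses(exp.__name__, exp.Expression, exclude=(exp.Expression,))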
148
149
150def apply_index_offset(
151    this: exp.Expression,
152    expressions: t.List[E],
153    offset: int,
154    dialect: DialectType = None,
155) -> t.List[E]:
156    """
157    Applies an offset to a given integer literal expression.
158
159    Args:
160        this: The target of the index.
161        expressions: The expression the offset will be applied to, wrapped in a list.
162        offset: The offset that will be applied.
163        dialect: the dialect of interest.
164
165    Returns:
166        The original expression with the offset applied to it, wrapped in a list. If the provided
167        `expressions` argument contains more than one expression, it's returned unaffected.
168    """
169    if not offset or len(expressions) != 1:
170        return expressions
171
172    expression = expressions[0]
173
174    from sqlglot import exp
175    from sqlglot.optimizer.annotate_types import annotate_types
176    from sqlglot.optimizer.simplify import simplify
177
178    if not this.type:
179        annotate_types(this, dialect=dialect)
180
181    if t.cast(exp.DataType, this.type).this not in (
182        exp.DataType.Type.UNKNOWN,
183        exp.DataType.Type.ARRAY,
184    ):
185        return expressions
186
187    if not expression.type:
188        annotate_types(expression, dialect=dialect)
189
190    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
191        logger.info("Applying array index offset (%s)", offset)
192        expression = simplify(expression + offset)
193        return [expression]
194
195    return expressions
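# Usage sketch (editor's addition, not part of the module source). A rough example of
# shifting a 1-based array subscript to 0-based; the exact result depends on how the
# nodes end up being type-annotated:
#
#     from sqlglot import exp, parse_one
#
#     bracket = parse_one("SELECT arr[1] FROM t").find(exp.Bracket)
#     apply_index_offset(bracket.this, bracket.expressions, -1)  # expected: a list holding the literal 0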
196
197
198def camel_to_snake_case(name: str) -> str:
199    """Converts `name` from camelCase to snake_case (uppercased) and returns the result."""
200    return CAMEL_CASE_PATTERN.sub("_", name).upper()
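# Usage sketch (editor's addition, not part of the module source). Note that the result
# is uppercased:
#
#     >>> camel_to_snake_case("indexOffset")
#     'INDEX_OFFSET'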
201
202
203def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
204    """
205    Applies a transformation to a given expression until a fix point is reached.
206
207    Args:
208        expression: The expression to be transformed.
209        func: The transformation to be applied.
210
211    Returns:
212        The transformed expression.
213    """
214    end_hash: t.Optional[int] = None
215
216    while True:
217        # No need to walk the AST; we've already cached the hashes in the previous iteration
218        if end_hash is None:
219            for n in reversed(tuple(expression.walk())):
220                n._hash = hash(n)
221
222        start_hash = hash(expression)
223        expression = func(expression)
224
225        expression_nodes = tuple(expression.walk())
226
227        # Uncache previous caches so we can recompute them
228        for n in reversed(expression_nodes):
229            n._hash = None
230            n._hash = hash(n)
231
232        end_hash = hash(expression)
233
234        if start_hash == end_hash:
235            # ... and reset the hash so we don't risk it becoming out of date if a mutation happens
236            for n in expression_nodes:
237                n._hash = None
238
239            break
240
241    return expression
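# Usage sketch (editor's addition, not part of the module source). A toy fix point:
# repeatedly unwrap parentheses until the expression stops changing.
#
#     from sqlglot import exp, parse_one
#
#     unwrap = lambda e: e.this if isinstance(e, exp.Paren) else e
#     while_changing(parse_one("((x))"), unwrap)  # peels the parens, returning the bare column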
242
243
244def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
245    """
246    Sorts a given directed acyclic graph in topological order.
247
248    Args:
249        dag: The graph to be sorted.
250
251    Returns:
252        A list that contains all of the graph's nodes in topological order.
253    """
254    result = []
255
256    for node, deps in tuple(dag.items()):
257        for dep in deps:
258            if dep not in dag:
259                dag[dep] = set()
260
261    while dag:
262        current = {node for node, deps in dag.items() if not deps}
263
264        if not current:
265            raise ValueError("Cycle error")
266
267        for node in current:
268            dag.pop(node)
269
270        for deps in dag.values():
271            deps -= current
272
273        result.extend(sorted(current))  # type: ignore
274
275    return result
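# Usage sketch (editor's addition, not part of the module source). Dependencies come
# first; note that the input dict is mutated in place:
#
#     >>> tsort({"a": {"b", "c"}, "b": {"c"}})
#     ['c', 'b', 'a']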
276
277
278def open_file(file_name: str) -> t.TextIO:
279    """Open a file that may be compressed as gzip and return it in universal newline mode."""
280    with open(file_name, "rb") as f:
281        gzipped = f.read(2) == b"\x1f\x8b"
282
283    if gzipped:
284        import gzip
285
286        return gzip.open(file_name, "rt", newline="")
287
288    return open(file_name, encoding="utf-8", newline="")
289
290
291@contextmanager
292def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
293    """
294    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
295
296    Args:
297        read_csv: A `ReadCSV` function call.
298
299    Yields:
300        A python csv reader.
301    """
302    args = read_csv.expressions
303    file = open_file(read_csv.name)
304
305    delimiter = ","
306    args = iter(arg.name for arg in args)  # type: ignore
307    for k, v in zip(args, args):
308        if k == "delimiter":
309            delimiter = v
310
311    try:
312        import csv as csv_
313
314        yield csv_.reader(file, delimiter=delimiter)
315    finally:
316        file.close()
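# Usage sketch (editor's addition; "data.csv" is a hypothetical file path):
#
#     from sqlglot import exp, parse_one
#
#     node = parse_one("SELECT * FROM READ_CSV('data.csv', 'delimiter', '|')").find(exp.ReadCSV)
#     with csv_reader(node) as reader:
#         for row in reader:
#             print(row)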
317
318
319def find_new_name(taken: t.Collection[str], base: str) -> str:
320    """
321    Searches for a new name.
322
323    Args:
324        taken: A collection of taken names.
325        base: Base name to alter.
326
327    Returns:
328        The new, available name.
329    """
330    if base not in taken:
331        return base
332
333    i = 2
334    new = f"{base}_{i}"
335    while new in taken:
336        i += 1
337        new = f"{base}_{i}"
338
339    return new
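# Usage sketch (editor's addition, not part of the module source):
#
#     >>> find_new_name({"a", "b"}, "c")
#     'c'
#     >>> find_new_name({"a", "a_2"}, "a")
#     'a_3'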
340
341
342def is_int(text: str) -> bool:
343    return is_type(text, int)
344
345
346def is_float(text: str) -> bool:
347    return is_type(text, float)
348
349
350def is_type(text: str, target_type: t.Type) -> bool:
351    try:
352        target_type(text)
353        return True
354    except ValueError:
355        return False
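# Usage sketch (editor's addition, not part of the module source):
#
#     >>> is_int("42")
#     True
#     >>> is_int("4.2")
#     False
#     >>> is_float("4.2")
#     True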
356
357
358def name_sequence(prefix: str) -> t.Callable[[], str]:
359    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
360    sequence = count()
361    return lambda: f"{prefix}{next(sequence)}"
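# Usage sketch (editor's addition, not part of the module source):
#
#     >>> next_name = name_sequence("t")
#     >>> next_name(), next_name(), next_name()
#     ('t0', 't1', 't2')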
362
363
364def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
365    """Returns a dictionary created from an object's attributes."""
366    return {
367        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
368        **kwargs,
369    }
370
371
372def split_num_words(
373    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
374) -> t.List[t.Optional[str]]:
375    """
376    Perform a split on a value and return N words as a result with `None` used for words that don't exist.
377
378    Args:
379        value: The value to be split.
380        sep: The value to use to split on.
381        min_num_words: The minimum number of words that are going to be in the result.
382        fill_from_start: Indicates whether `None` values should be inserted at the start or end of the list.
383
384    Examples:
385        >>> split_num_words("db.table", ".", 3)
386        [None, 'db', 'table']
387        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
388        ['db', 'table', None]
389        >>> split_num_words("db.table", ".", 1)
390        ['db', 'table']
391
392    Returns:
393        The list of words returned by `split`, possibly augmented by a number of `None` values.
394    """
395    words = value.split(sep)
396    if fill_from_start:
397        return [None] * (min_num_words - len(words)) + words
398    return words + [None] * (min_num_words - len(words))
399
400
401def is_iterable(value: t.Any) -> bool:
402    """
403    Checks if the value is an iterable, excluding the types `str` and `bytes`.
404
405    Examples:
406        >>> is_iterable([1,2])
407        True
408        >>> is_iterable("test")
409        False
410
411    Args:
412        value: The value to check if it is an iterable.
413
414    Returns:
415        A `bool` value indicating if it is an iterable.
416    """
417    from sqlglot import Expression
418
419    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
420
421
422def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
423    """
424    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
425    type `str` and `bytes` are not regarded as iterables.
426
427    Examples:
428        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
429        [1, 2, 3, 4, 5, 'bla']
430        >>> list(flatten([1, 2, 3]))
431        [1, 2, 3]
432
433    Args:
434        values: The value to be flattened.
435
436    Yields:
437        Non-iterable elements in `values`.
438    """
439    for value in values:
440        if is_iterable(value):
441            yield from flatten(value)
442        else:
443            yield value
444
445
446def dict_depth(d: t.Dict) -> int:
447    """
448    Get the nesting depth of a dictionary.
449
450    Example:
451        >>> dict_depth(None)
452        0
453        >>> dict_depth({})
454        1
455        >>> dict_depth({"a": "b"})
456        1
457        >>> dict_depth({"a": {}})
458        2
459        >>> dict_depth({"a": {"b": {}}})
460        3
461    """
462    try:
463        return 1 + dict_depth(next(iter(d.values())))
464    except AttributeError:
465        # d doesn't have attribute "values"
466        return 0
467    except StopIteration:
468        # d.values() returns an empty sequence
469        return 1
470
471
472def first(it: t.Iterable[T]) -> T:
473    """Returns the first element from an iterable (useful for sets)."""
474    return next(i for i in it)
475
476
477def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
478    if isinstance(value, bool) or value is None:
479        return value
480
481    # Coerce the value to a boolean if it matches one of the truthy/falsy values below
482    value_lower = value.lower()
483    if value_lower in ("true", "1"):
484        return True
485    if value_lower in ("false", "0"):
486        return False
487
488    return value
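# Usage sketch (editor's addition, not part of the module source). Strings that don't
# look boolean are returned unchanged:
#
#     >>> to_bool("TRUE")
#     True
#     >>> to_bool("0")
#     False
#     >>> to_bool("maybe")
#     'maybe'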
489
490
491def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
492    """
493    Merges a sequence of ranges, represented as tuples (low, high) whose values
494    belong to some totally-ordered set.
495
496    Example:
497        >>> merge_ranges([(1, 3), (2, 6)])
498        [(1, 6)]
499    """
500    if not ranges:
501        return []
502
503    ranges = sorted(ranges)
504
505    merged = [ranges[0]]
506
507    for start, end in ranges[1:]:
508        last_start, last_end = merged[-1]
509
510        if start <= last_end:
511            merged[-1] = (last_start, max(last_end, end))
512        else:
513            merged.append((start, end))
514
515    return merged
516
517
518def is_iso_date(text: str) -> bool:
519    try:
520        datetime.date.fromisoformat(text)
521        return True
522    except ValueError:
523        return False
524
525
526def is_iso_datetime(text: str) -> bool:
527    try:
528        datetime.datetime.fromisoformat(text)
529        return True
530    except ValueError:
531        return False
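# Usage sketch (editor's addition, not part of the module source) covering both
# is_iso_date and is_iso_datetime:
#
#     >>> is_iso_date("2023-01-31")
#     True
#     >>> is_iso_datetime("2023-01-31 12:00:00")
#     True
#     >>> is_iso_date("31/01/2023")
#     False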
532
533
534# Interval units that operate on date components
535DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
536
537
538def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
539    return expression is not None and expression.name.lower() in DATE_UNITS
540
541
542K = t.TypeVar("K")
543V = t.TypeVar("V")
544
545
546class SingleValuedMapping(t.Mapping[K, V]):
547    """
548    Mapping where all keys return the same value.
549
550    This rigamarole is meant to avoid copying keys, which was originally intended
551    as an optimization while qualifying columns for tables with lots of columns.
552    """
553
554    def __init__(self, keys: t.Collection[K], value: V):
555        self._keys = keys if isinstance(keys, Set) else set(keys)
556        self._value = value
557
558    def __getitem__(self, key: K) -> V:
559        if key in self._keys:
560            return self._value
561        raise KeyError(key)
562
563    def __len__(self) -> int:
564        return len(self._keys)
565
566    def __iter__(self) -> t.Iterator[K]:
567        return iter(self._keys)
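# Usage sketch (editor's addition, not part of the module source). Every key resolves
# to the same value without copying the key collection into a dict:
#
#     >>> columns = SingleValuedMapping(["a", "b", "c"], "t1")
#     >>> columns["b"]
#     't1'
#     >>> len(columns)
#     3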