Edit on GitHub

sqlglot.helper

  1from __future__ import annotations
  2
  3import datetime
  4import inspect
  5import logging
  6import re
  7import sys
  8import typing as t
  9from collections.abc import Collection, Set
 10from contextlib import contextmanager
 11from copy import copy
 12from difflib import get_close_matches
 13from enum import Enum
 14from itertools import count
 15
 16if t.TYPE_CHECKING:
 17    from sqlglot import exp
 18    from sqlglot._typing import A, E, T
 19    from sqlglot.dialects.dialect import DialectType
 20    from sqlglot.expressions import Expression
 21
 22
# Zero-width matches at the position before each interior uppercase letter,
# i.e. the word boundaries of a camelCase identifier.
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
# (major, minor) of the running Python interpreter, e.g. (3, 11).
PYTHON_VERSION = sys.version_info[:2]
# Shared package-level logger.
logger = logging.getLogger("sqlglot")
 26
 27
class AutoName(Enum):
    """
    This is used for creating Enum classes where `auto()` is the string form
    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    # Enum machinery hook invoked by `auto()`. Note: this is NOT an instance
    # method -- the enum protocol passes the member's name positionally as the
    # first argument, so there is deliberately no `self`.
    def _generate_next_value_(name, _start, _count, _last_values):
        return name
 38
 39
class classproperty(property):
    """
    Similar to a normal property but works for class methods
    """

    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
        # Wrap the stored getter in a `classmethod` and bind it to `owner`,
        # so the getter is always called with the class itself (never an
        # instance), then invoke it immediately to produce the value.
        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
 47
 48
 49def suggest_closest_match_and_fail(
 50    kind: str,
 51    word: str,
 52    possibilities: t.Iterable[str],
 53) -> None:
 54    close_matches = get_close_matches(word, possibilities, n=1)
 55
 56    similar = seq_get(close_matches, 0) or ""
 57    if similar:
 58        similar = f" Did you mean {similar}?"
 59
 60    raise ValueError(f"Unknown {kind} '{word}'.{similar}")
 61
 62
 63def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 64    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 65    try:
 66        return seq[index]
 67    except IndexError:
 68        return None
 69
 70
 71@t.overload
 72def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
 73
 74
 75@t.overload
 76def ensure_list(value: None) -> t.List: ...
 77
 78
 79@t.overload
 80def ensure_list(value: T) -> t.List[T]: ...
 81
 82
 83def ensure_list(value):
 84    """
 85    Ensures that a value is a list, otherwise casts or wraps it into one.
 86
 87    Args:
 88        value: The value of interest.
 89
 90    Returns:
 91        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 92    """
 93    if value is None:
 94        return []
 95    if isinstance(value, (list, tuple)):
 96        return list(value)
 97
 98    return [value]
 99
100
@t.overload
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...


@t.overload
def ensure_collection(value: T) -> t.Collection[T]: ...


def ensure_collection(value):
    """
    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.

    Args:
        value: The value of interest.

    Returns:
        The value if it's a collection, or else the value wrapped in a list.
    """
    if value is None:
        return []
    if isinstance(value, Collection) and not isinstance(value, (str, bytes)):
        return value

    return [value]
124
125
def csv(*args: str, sep: str = ", ") -> str:
    """
    Formats any number of string arguments as CSV.

    Args:
        args: The string arguments to format.
        sep: The argument separator.

    Returns:
        The arguments formatted as a CSV string.
    """
    # Empty (falsy) arguments are dropped before joining.
    return sep.join(filter(None, args))
138
139
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: The name of the module to search for subclasses in.
        classes: Class(es) we want to find the subclasses of.
        exclude: Class(es) we want to exclude from the returned list.

    Returns:
        The target subclasses.
    """
    module = sys.modules[module_name]

    def _is_target(member: t.Any) -> bool:
        # A member qualifies if it's a class, derives from `classes`, and isn't excluded.
        return inspect.isclass(member) and issubclass(member, classes) and member not in exclude

    return [member for _, member in inspect.getmembers(module, _is_target)]
163
164
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
    dialect: DialectType = None,
) -> t.List[E]:
    """
    Applies an offset to a given integer literal expression.

    Args:
        this: The target of the index.
        expressions: The expression the offset will be applied to, wrapped in a list.
        offset: The offset that will be applied.
        dialect: the dialect of interest.

    Returns:
        The original expression with the offset applied to it, wrapped in a list. If the provided
        `expressions` argument contains more than one expression, it's returned unaffected.
    """
    # Nothing to do when there is no offset, or the index isn't a single expression.
    if not offset or len(expressions) != 1:
        return expressions

    expression = expressions[0]

    # Imported locally, presumably to avoid circular imports -- TODO confirm.
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    # NOTE(review): annotate_types appears to fill in `.type` in place; the
    # checks below rely on that side effect.
    if not this.type:
        annotate_types(this, dialect=dialect)

    # Only adjust indexes into arrays (UNKNOWN is treated permissively).
    if t.cast(exp.DataType, this.type).this not in (
        exp.DataType.Type.UNKNOWN,
        exp.DataType.Type.ARRAY,
    ):
        return expressions

    if not expression.type:
        annotate_types(expression, dialect=dialect)

    # Only integer indexes can be shifted; anything else is returned untouched.
    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
        logger.info("Applying array index offset (%s)", offset)
        expression = simplify(expression + offset)
        return [expression]

    return expressions
211
212
def camel_to_snake_case(name: str) -> str:
    """Converts `name` from camelCase to uppercased snake case (e.g. "fooBar" -> "FOO_BAR")."""
    # Insert "_" before every interior uppercase letter, then uppercase the whole string.
    underscored = re.sub("(?<!^)(?=[A-Z])", "_", name)
    return underscored.upper()
216
217
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Applies a transformation to a given expression until a fix point is reached.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied.

    Returns:
        The transformed expression.
    """
    end_hash: t.Optional[int] = None

    while True:
        # No need to walk the AST– we've already cached the hashes in the previous iteration
        if end_hash is None:
            # First iteration: prime the per-node hash cache. NOTE(review): nodes
            # are hashed in reversed walk order, presumably so children are hashed
            # (and cached) before their parents -- confirm against Expression.walk().
            for n in reversed(tuple(expression.walk())):
                n._hash = hash(n)

        start_hash = hash(expression)
        expression = func(expression)

        expression_nodes = tuple(expression.walk())

        # Uncache previous caches so we can recompute them
        for n in reversed(expression_nodes):
            n._hash = None
            n._hash = hash(n)

        end_hash = hash(expression)

        # Fix point: the transformation produced a tree with the same hash.
        if start_hash == end_hash:
            # ... and reset the hash so we don't risk it becoming out of date if a mutation happens
            for n in expression_nodes:
                n._hash = None

            break

    return expression
257
258
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Note: the input mapping is consumed (emptied) in the process.

    Args:
        dag: The graph to be sorted.

    Returns:
        A list that contains all of the graph's nodes in topological order.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    result: t.List[T] = []

    # Ensure every node referenced as a dependency also has its own entry.
    for deps in tuple(dag.values()):
        for dep in deps:
            dag.setdefault(dep, set())

    # Kahn's algorithm: repeatedly peel off the nodes with no remaining deps.
    while dag:
        ready = {node for node, deps in dag.items() if not deps}

        if not ready:
            raise ValueError("Cycle error")

        for node in ready:
            dag.pop(node)

        for deps in dag.values():
            deps -= ready

        result.extend(sorted(ready))  # type: ignore

    return result
291
292
def open_file(file_name: str) -> t.TextIO:
    """Open a file that may be compressed as gzip and return it in universal newline mode."""
    # Sniff the first two bytes for the gzip magic number.
    with open(file_name, "rb") as probe:
        is_gzipped = probe.read(2) == b"\x1f\x8b"

    if is_gzipped:
        import gzip

        return gzip.open(file_name, "rt", newline="")

    return open(file_name, encoding="utf-8", newline="")
304
305
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader.
    """
    file = open_file(read_csv.name)

    # The remaining arguments come in key/value pairs, e.g. ('delimiter', '|').
    delimiter = ","
    names = iter(arg.name for arg in read_csv.expressions)  # type: ignore
    for key, value in zip(names, names):
        if key == "delimiter":
            delimiter = value

    try:
        import csv as csv_

        yield csv_.reader(file, delimiter=delimiter)
    finally:
        file.close()
332
333
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Searches for a new name.

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        The new, available name.
    """
    if base not in taken:
        return base

    # Try "base_2", "base_3", ... until a free name is found.
    for suffix in count(2):
        candidate = f"{base}_{suffix}"
        if candidate not in taken:
            return candidate
355
356
def is_int(text: str) -> bool:
    """Returns `True` if `text` can be parsed as an `int`."""
    return is_type(text, int)


def is_float(text: str) -> bool:
    """Returns `True` if `text` can be parsed as a `float`."""
    return is_type(text, float)


def is_type(text: str, target_type: t.Type) -> bool:
    """Returns `True` if calling `target_type(text)` succeeds without raising `ValueError`."""
    try:
        target_type(text)
    except ValueError:
        return False
    return True
371
372
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
    counter = count()

    def _next_name() -> str:
        return f"{prefix}{next(counter)}"

    return _next_name
377
378
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """Returns a dictionary created from an object's attributes."""
    # Shallow-copy each attribute, preferring the value's own `copy` method
    # when it has one, and falling back to `copy.copy` otherwise.
    attributes = {}
    for name, attr in vars(obj).items():
        attributes[name] = attr.copy() if hasattr(attr, "copy") else copy(attr)

    # Keyword arguments override (or extend) the copied attributes.
    attributes.update(kwargs)
    return attributes
385
386
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Perform a split on a value and return N words as a result with `None` used for words that don't exist.

    Args:
        value: The value to be split.
        sep: The value to use to split on.
        min_num_words: The minimum number of words that are going to be in the result.
        fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly augmented by a number of `None` values.
    """
    words: t.List[t.Optional[str]] = list(value.split(sep))
    # A non-positive pad count yields an empty padding list.
    padding: t.List[t.Optional[str]] = [None] * (min_num_words - len(words))

    return padding + words if fill_from_start else words + padding
414
415
def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is an iterable, excluding the types `str` and `bytes`.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    from sqlglot import Expression

    # Strings, bytes and sqlglot expressions are iterable but deliberately excluded.
    if isinstance(value, (str, bytes, Expression)):
        return False

    return hasattr(value, "__iter__")
435
436
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
    type `str` and `bytes` are not regarded as iterables.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`.
    """
    for item in values:
        if not is_iterable(item):
            yield item
        else:
            # Recurse into nested iterables.
            yield from flatten(item)
459
460
def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        # Probe the first value; only the mapping access itself stays in the try.
        first_value = next(iter(d.values()))
    except AttributeError:
        # `d` has no `values` attribute, i.e. it isn't a mapping.
        return 0
    except StopIteration:
        # `d` is an empty mapping.
        return 1

    return 1 + dict_depth(first_value)
485
486
def first(it: t.Iterable[T]) -> T:
    """Returns the first element from an iterable (useful for sets)."""
    return next(iter(it))
490
491
def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
    """
    Coerces a string to a boolean when it matches a known truthy/falsy spelling
    ("true"/"1" and "false"/"0", case-insensitive); booleans and `None` pass
    through, and any other string is returned unchanged.
    """
    if value is None or isinstance(value, bool):
        return value

    normalized = value.lower()
    if normalized in {"true", "1"}:
        return True
    if normalized in {"false", "0"}:
        return False

    return value
504
505
def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Merges a sequence of ranges, represented as tuples (low, high) whose values
    belong to some totally-ordered set.

    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    ordered = sorted(ranges)
    merged = [ordered[0]]

    for low, high in ordered[1:]:
        prev_low, prev_high = merged[-1]

        if low > prev_high:
            # Disjoint from the previous range: start a new one.
            merged.append((low, high))
        else:
            # Overlapping (or touching): extend the previous range.
            merged[-1] = (prev_low, max(prev_high, high))

    return merged
531
532
def is_iso_date(text: str) -> bool:
    """Returns `True` if `text` is a valid ISO-8601 date, e.g. "2024-01-31"."""
    try:
        datetime.date.fromisoformat(text)
    except ValueError:
        return False
    return True
539
540
def is_iso_datetime(text: str) -> bool:
    """Returns `True` if `text` is a valid ISO-8601 datetime, e.g. "2024-01-31T12:00:00"."""
    try:
        datetime.datetime.fromisoformat(text)
    except ValueError:
        return False
    return True
547
548
# Interval units that operate on date components
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}


def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
    """Returns `True` if `expression` names a date-component interval unit (see `DATE_UNITS`)."""
    if expression is None:
        return False

    return expression.name.lower() in DATE_UNITS
555
556
K = t.TypeVar("K")
V = t.TypeVar("V")


class SingleValuedMapping(t.Mapping[K, V]):
    """
    Mapping where all keys return the same value.

    This rigamarole is meant to avoid copying keys, which was originally intended
    as an optimization while qualifying columns for tables with lots of columns.
    """

    def __init__(self, keys: t.Collection[K], value: V):
        # Reuse the caller's set as-is when possible; otherwise build one.
        self._keys = keys if isinstance(keys, Set) else set(keys)
        self._value = value

    def __getitem__(self, key: K) -> V:
        if key not in self._keys:
            raise KeyError(key)

        return self._value

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> t.Iterator[K]:
        return iter(self._keys)
CAMEL_CASE_PATTERN = re.compile('(?<!^)(?=[A-Z])')
PYTHON_VERSION = (3, 10)
logger = <Logger sqlglot (WARNING)>
class AutoName(enum.Enum):
29class AutoName(Enum):
30    """
31    This is used for creating Enum classes where `auto()` is the string form
32    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
33
34    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
35    """
36
37    def _generate_next_value_(name, _start, _count, _last_values):
38        return name

This is used for creating Enum classes where auto() is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values

class classproperty(builtins.property):
41class classproperty(property):
42    """
43    Similar to a normal property but works for class methods
44    """
45
46    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
47        return classmethod(self.fget).__get__(None, owner)()  # type: ignore

Similar to a normal property but works for class methods

def suggest_closest_match_and_fail(kind: str, word: str, possibilities: Iterable[str]) -> None:
50def suggest_closest_match_and_fail(
51    kind: str,
52    word: str,
53    possibilities: t.Iterable[str],
54) -> None:
55    close_matches = get_close_matches(word, possibilities, n=1)
56
57    similar = seq_get(close_matches, 0) or ""
58    if similar:
59        similar = f" Did you mean {similar}?"
60
61    raise ValueError(f"Unknown {kind} '{word}'.{similar}")
def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:
64def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
65    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
66    try:
67        return seq[index]
68    except IndexError:
69        return None

Returns the value in seq at position index, or None if index is out of bounds.

def ensure_list(value):
84def ensure_list(value):
85    """
86    Ensures that a value is a list, otherwise casts or wraps it into one.
87
88    Args:
89        value: The value of interest.
90
91    Returns:
92        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
93    """
94    if value is None:
95        return []
96    if isinstance(value, (list, tuple)):
97        return list(value)
98
99    return [value]

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: The value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.

def ensure_collection(value):
110def ensure_collection(value):
111    """
112    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
113
114    Args:
115        value: The value of interest.
116
117    Returns:
118        The value if it's a collection, or else the value wrapped in a list.
119    """
120    if value is None:
121        return []
122    return (
123        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
124    )

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: The value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.

def csv(*args: str, sep: str = ', ') -> str:
127def csv(*args: str, sep: str = ", ") -> str:
128    """
129    Formats any number of string arguments as CSV.
130
131    Args:
132        args: The string arguments to format.
133        sep: The argument separator.
134
135    Returns:
136        The arguments formatted as a CSV string.
137    """
138    return sep.join(arg for arg in args if arg)

Formats any number of string arguments as CSV.

Arguments:
  • args: The string arguments to format.
  • sep: The argument separator.
Returns:

The arguments formatted as a CSV string.

def subclasses( module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:
141def subclasses(
142    module_name: str,
143    classes: t.Type | t.Tuple[t.Type, ...],
144    exclude: t.Type | t.Tuple[t.Type, ...] = (),
145) -> t.List[t.Type]:
146    """
147    Returns all subclasses for a collection of classes, possibly excluding some of them.
148
149    Args:
150        module_name: The name of the module to search for subclasses in.
151        classes: Class(es) we want to find the subclasses of.
152        exclude: Class(es) we want to exclude from the returned list.
153
154    Returns:
155        The target subclasses.
156    """
157    return [
158        obj
159        for _, obj in inspect.getmembers(
160            sys.modules[module_name],
161            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
162        )
163    ]

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: The name of the module to search for subclasses in.
  • classes: Class(es) we want to find the subclasses of.
  • exclude: Class(es) we want to exclude from the returned list.
Returns:

The target subclasses.

def apply_index_offset( this: sqlglot.expressions.Expression, expressions: List[~E], offset: int, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> List[~E]:
166def apply_index_offset(
167    this: exp.Expression,
168    expressions: t.List[E],
169    offset: int,
170    dialect: DialectType = None,
171) -> t.List[E]:
172    """
173    Applies an offset to a given integer literal expression.
174
175    Args:
176        this: The target of the index.
177        expressions: The expression the offset will be applied to, wrapped in a list.
178        offset: The offset that will be applied.
179        dialect: the dialect of interest.
180
181    Returns:
182        The original expression with the offset applied to it, wrapped in a list. If the provided
183        `expressions` argument contains more than one expression, it's returned unaffected.
184    """
185    if not offset or len(expressions) != 1:
186        return expressions
187
188    expression = expressions[0]
189
190    from sqlglot import exp
191    from sqlglot.optimizer.annotate_types import annotate_types
192    from sqlglot.optimizer.simplify import simplify
193
194    if not this.type:
195        annotate_types(this, dialect=dialect)
196
197    if t.cast(exp.DataType, this.type).this not in (
198        exp.DataType.Type.UNKNOWN,
199        exp.DataType.Type.ARRAY,
200    ):
201        return expressions
202
203    if not expression.type:
204        annotate_types(expression, dialect=dialect)
205
206    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
207        logger.info("Applying array index offset (%s)", offset)
208        expression = simplify(expression + offset)
209        return [expression]
210
211    return expressions

Applies an offset to a given integer literal expression.

Arguments:
  • this: The target of the index.
  • expressions: The expression the offset will be applied to, wrapped in a list.
  • offset: The offset that will be applied.
  • dialect: the dialect of interest.
Returns:

The original expression with the offset applied to it, wrapped in a list. If the provided expressions argument contains more than one expression, it's returned unaffected.

def camel_to_snake_case(name: str) -> str:
214def camel_to_snake_case(name: str) -> str:
215    """Converts `name` from camelCase to snake_case and returns the result."""
216    return CAMEL_CASE_PATTERN.sub("_", name).upper()

Converts name from camelCase to snake_case and returns the result.

def while_changing( expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:
219def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
220    """
221    Applies a transformation to a given expression until a fix point is reached.
222
223    Args:
224        expression: The expression to be transformed.
225        func: The transformation to be applied.
226
227    Returns:
228        The transformed expression.
229    """
230    end_hash: t.Optional[int] = None
231
232    while True:
233        # No need to walk the AST– we've already cached the hashes in the previous iteration
234        if end_hash is None:
235            for n in reversed(tuple(expression.walk())):
236                n._hash = hash(n)
237
238        start_hash = hash(expression)
239        expression = func(expression)
240
241        expression_nodes = tuple(expression.walk())
242
243        # Uncache previous caches so we can recompute them
244        for n in reversed(expression_nodes):
245            n._hash = None
246            n._hash = hash(n)
247
248        end_hash = hash(expression)
249
250        if start_hash == end_hash:
251            # ... and reset the hash so we don't risk it becoming out of date if a mutation happens
252            for n in expression_nodes:
253                n._hash = None
254
255            break
256
257    return expression

Applies a transformation to a given expression until a fix point is reached.

Arguments:
  • expression: The expression to be transformed.
  • func: The transformation to be applied.
Returns:

The transformed expression.

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:
260def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
261    """
262    Sorts a given directed acyclic graph in topological order.
263
264    Args:
265        dag: The graph to be sorted.
266
267    Returns:
268        A list that contains all of the graph's nodes in topological order.
269    """
270    result = []
271
272    for node, deps in tuple(dag.items()):
273        for dep in deps:
274            if dep not in dag:
275                dag[dep] = set()
276
277    while dag:
278        current = {node for node, deps in dag.items() if not deps}
279
280        if not current:
281            raise ValueError("Cycle error")
282
283        for node in current:
284            dag.pop(node)
285
286        for deps in dag.values():
287            deps -= current
288
289        result.extend(sorted(current))  # type: ignore
290
291    return result

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: The graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.

def open_file(file_name: str) -> <class 'TextIO'>:
294def open_file(file_name: str) -> t.TextIO:
295    """Open a file that may be compressed as gzip and return it in universal newline mode."""
296    with open(file_name, "rb") as f:
297        gzipped = f.read(2) == b"\x1f\x8b"
298
299    if gzipped:
300        import gzip
301
302        return gzip.open(file_name, "rt", newline="")
303
304    return open(file_name, encoding="utf-8", newline="")

Open a file that may be compressed as gzip and return it in universal newline mode.

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:
307@contextmanager
308def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
309    """
310    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
311
312    Args:
313        read_csv: A `ReadCSV` function call.
314
315    Yields:
316        A python csv reader.
317    """
318    args = read_csv.expressions
319    file = open_file(read_csv.name)
320
321    delimiter = ","
322    args = iter(arg.name for arg in args)  # type: ignore
323    for k, v in zip(args, args):
324        if k == "delimiter":
325            delimiter = v
326
327    try:
328        import csv as csv_
329
330        yield csv_.reader(file, delimiter=delimiter)
331    finally:
332        file.close()

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: A ReadCSV function call.
Yields:

A python csv reader.

def find_new_name(taken: Collection[str], base: str) -> str:
335def find_new_name(taken: t.Collection[str], base: str) -> str:
336    """
337    Searches for a new name.
338
339    Args:
340        taken: A collection of taken names.
341        base: Base name to alter.
342
343    Returns:
344        The new, available name.
345    """
346    if base not in taken:
347        return base
348
349    i = 2
350    new = f"{base}_{i}"
351    while new in taken:
352        i += 1
353        new = f"{base}_{i}"
354
355    return new

Searches for a new name.

Arguments:
  • taken: A collection of taken names.
  • base: Base name to alter.
Returns:

The new, available name.

def is_int(text: str) -> bool:
358def is_int(text: str) -> bool:
359    return is_type(text, int)
def is_float(text: str) -> bool:
362def is_float(text: str) -> bool:
363    return is_type(text, float)
def is_type(text: str, target_type: Type) -> bool:
366def is_type(text: str, target_type: t.Type) -> bool:
367    try:
368        target_type(text)
369        return True
370    except ValueError:
371        return False
def name_sequence(prefix: str) -> Callable[[], str]:
374def name_sequence(prefix: str) -> t.Callable[[], str]:
375    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
376    sequence = count()
377    return lambda: f"{prefix}{next(sequence)}"

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").

def object_to_dict(obj: Any, **kwargs) -> Dict:
380def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
381    """Returns a dictionary created from an object's attributes."""
382    return {
383        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
384        **kwargs,
385    }

Returns a dictionary created from an object's attributes.

def split_num_words( value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:
388def split_num_words(
389    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
390) -> t.List[t.Optional[str]]:
391    """
392    Perform a split on a value and return N words as a result with `None` used for words that don't exist.
393
394    Args:
395        value: The value to be split.
396        sep: The value to use to split on.
397        min_num_words: The minimum number of words that are going to be in the result.
398        fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list.
399
400    Examples:
401        >>> split_num_words("db.table", ".", 3)
402        [None, 'db', 'table']
403        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
404        ['db', 'table', None]
405        >>> split_num_words("db.table", ".", 1)
406        ['db', 'table']
407
408    Returns:
409        The list of words returned by `split`, possibly augmented by a number of `None` values.
410    """
411    words = value.split(sep)
412    if fill_from_start:
413        return [None] * (min_num_words - len(words)) + words
414    return words + [None] * (min_num_words - len(words))

Perform a split on a value and return N words as a result with None used for words that don't exist.

Arguments:
  • value: The value to be split.
  • sep: The value to use to split on.
  • min_num_words: The minimum number of words that are going to be in the result.
  • fill_from_start: Indicates that if None values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:
def is_iterable(value: t.Any) -> bool:
    """
    Check whether `value` is iterable, treating `str`, `bytes` and sqlglot
    Expressions as non-iterable atoms.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check.

    Returns:
        A `bool` indicating whether the value counts as an iterable.
    """
    from sqlglot import Expression

    # Excluded types support __iter__ but should be handled as scalars here.
    if isinstance(value, (str, bytes, Expression)):
        return False
    return hasattr(value, "__iter__")

Checks if the value is an iterable, excluding the types str and bytes.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: The value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Recursively flatten an iterable that may mix iterable and non-iterable
    elements. Per `is_iterable`, `str` and `bytes` count as atoms.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        The non-iterable elements, in depth-first order.
    """
    for item in values:
        if not is_iterable(item):
            yield item
        else:
            yield from flatten(item)

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: The value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:
def dict_depth(d: t.Dict) -> int:
    """
    Compute how deeply a dictionary is nested.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    # Anything without a `values` attribute (None, scalars, ...) has depth 0.
    if not hasattr(d, "values"):
        return 0
    try:
        first_value = next(iter(d.values()))
    except StopIteration:
        # An empty dict is still one level deep.
        return 1
    return 1 + dict_depth(first_value)

Get the nesting depth of a dictionary.

Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: Iterable[~T]) -> ~T:
def first(it: t.Iterable[T]) -> T:
    """Return the first element of an iterable (useful for sets).

    Raises:
        StopIteration: If `it` is empty.
    """
    # next(iter(it)) is the idiomatic form; the original wrapped `it` in a
    # needless generator expression (`next(i for i in it)`).
    return next(iter(it))

Returns the first element from an iterable (useful for sets).

def to_bool(value: Union[str, bool, NoneType]) -> Union[str, bool, NoneType]:
def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
    """Coerce truthy/falsy strings ("true"/"1", "false"/"0") to booleans.

    Booleans and `None` pass through untouched; any other string is
    returned unchanged.
    """
    if value is None or isinstance(value, bool):
        return value

    lowered = value.lower()
    if lowered in {"true", "1"}:
        return True
    if lowered in {"false", "0"}:
        return False

    # Not a recognized boolean spelling: hand the original value back.
    return value
def merge_ranges(ranges: List[Tuple[~A, ~A]]) -> List[Tuple[~A, ~A]]:
def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Collapse a sequence of (low, high) ranges over a totally-ordered set
    into a sorted list of disjoint ranges.

    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    # Sorting guarantees each range can only overlap the one merged last.
    ordered = sorted(ranges)
    merged = [ordered[0]]

    for low, high in ordered[1:]:
        prev_low, prev_high = merged[-1]
        if low > prev_high:
            merged.append((low, high))
        else:
            # Overlapping or touching: extend the previous range in place.
            merged[-1] = (prev_low, max(prev_high, high))

    return merged

Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.

Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]
def is_iso_date(text: str) -> bool:
def is_iso_date(text: str) -> bool:
    """Whether `text` parses as an ISO-8601 date (e.g. '2023-01-31')."""
    try:
        datetime.date.fromisoformat(text)
    except ValueError:
        return False
    return True
def is_iso_datetime(text: str) -> bool:
def is_iso_datetime(text: str) -> bool:
    """Whether `text` parses as an ISO-8601 datetime (e.g. '2023-01-31T12:00:00')."""
    try:
        datetime.datetime.fromisoformat(text)
    except ValueError:
        return False
    return True
DATE_UNITS = {'day', 'year', 'year_month', 'quarter', 'week', 'month'}
def is_date_unit(expression: Optional[sqlglot.expressions.Expression]) -> bool:
def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
    """Whether the expression's name (lowercased) is one of the DATE_UNITS."""
    if expression is None:
        return False
    return expression.name.lower() in DATE_UNITS
class SingleValuedMapping(typing.Mapping[~K, ~V]):
class SingleValuedMapping(t.Mapping[K, V]):
    """
    Mapping in which every key maps to one shared value.

    This rigamarole is meant to avoid copying keys, which was originally intended
    as an optimization while qualifying columns for tables with lots of columns.
    """

    def __init__(self, keys: t.Collection[K], value: V):
        # Keep an existing Set as-is to avoid a copy; otherwise materialize one.
        self._keys = set(keys) if not isinstance(keys, Set) else keys
        self._value = value

    def __getitem__(self, key: K) -> V:
        if key not in self._keys:
            raise KeyError(key)
        return self._value

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> t.Iterator[K]:
        yield from self._keys

Mapping where all keys return the same value.

This rigamarole is meant to avoid copying keys, which was originally intended as an optimization while qualifying columns for tables with lots of columns.

SingleValuedMapping(keys: Collection[~K], value: ~V)
570    def __init__(self, keys: t.Collection[K], value: V):
571        self._keys = keys if isinstance(keys, Set) else set(keys)
572        self._value = value