
sqlglot.helper

  1from __future__ import annotations
  2
  3import datetime
  4import inspect
  5import logging
  6import re
  7import sys
  8import typing as t
  9from collections.abc import Collection, Set
 10from contextlib import contextmanager
 11from copy import copy
 12from difflib import get_close_matches
 13from enum import Enum
 14from itertools import count
 15
 16if t.TYPE_CHECKING:
 17    from sqlglot import exp
 18    from sqlglot._typing import A, E, T
 19    from sqlglot.dialects.dialect import DialectType
 20    from sqlglot.expressions import Expression
 21
 22
 23CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
 24PYTHON_VERSION = sys.version_info[:2]
 25logger = logging.getLogger("sqlglot")
 26
 27
 28class AutoName(Enum):
 29    """
 30    This is used for creating Enum classes where `auto()` is the string form
 31    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
 32
 33    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
 34    """
 35
 36    def _generate_next_value_(name, _start, _count, _last_values):
 37        return name
 38
 39
 40class classproperty(property):
 41    """
 42    Similar to a normal property but works for class methods
 43    """
 44
 45    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
 46        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
 47
 48
 49def suggest_closest_match_and_fail(
 50    kind: str,
 51    word: str,
 52    possibilities: t.Iterable[str],
 53) -> None:
 54    close_matches = get_close_matches(word, possibilities, n=1)
 55
 56    similar = seq_get(close_matches, 0) or ""
 57    if similar:
 58        similar = f" Did you mean {similar}?"
 59
 60    raise ValueError(f"Unknown {kind} '{word}'.{similar}")
 61
 62
 63def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 64    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 65    try:
 66        return seq[index]
 67    except IndexError:
 68        return None
 69
 70
 71@t.overload
 72def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
 73
 74
 75@t.overload
 76def ensure_list(value: None) -> t.List: ...
 77
 78
 79@t.overload
 80def ensure_list(value: T) -> t.List[T]: ...
 81
 82
 83def ensure_list(value):
 84    """
 85    Ensures that a value is a list, otherwise casts or wraps it into one.
 86
 87    Args:
 88        value: The value of interest.
 89
 90    Returns:
 91        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 92    """
 93    if value is None:
 94        return []
 95    if isinstance(value, (list, tuple)):
 96        return list(value)
 97
 98    return [value]
 99
100
101@t.overload
102def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
103
104
105@t.overload
106def ensure_collection(value: T) -> t.Collection[T]: ...
107
108
109def ensure_collection(value):
110    """
111    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
112
113    Args:
114        value: The value of interest.
115
116    Returns:
117        The value if it's a collection, or else the value wrapped in a list.
118    """
119    if value is None:
120        return []
121    return (
122        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
123    )
124
125
126def csv(*args: str, sep: str = ", ") -> str:
127    """
128    Formats any number of string arguments as CSV.
129
130    Args:
131        args: The string arguments to format.
132        sep: The argument separator.
133
134    Returns:
135        The arguments formatted as a CSV string.
136    """
137    return sep.join(arg for arg in args if arg)
138
139
140def subclasses(
141    module_name: str,
142    classes: t.Type | t.Tuple[t.Type, ...],
143    exclude: t.Type | t.Tuple[t.Type, ...] = (),
144) -> t.List[t.Type]:
145    """
146    Returns all subclasses for a collection of classes, possibly excluding some of them.
147
148    Args:
149        module_name: The name of the module to search for subclasses in.
150        classes: Class(es) we want to find the subclasses of.
151        exclude: Class(es) we want to exclude from the returned list.
152
153    Returns:
154        The target subclasses.
155    """
156    return [
157        obj
158        for _, obj in inspect.getmembers(
159            sys.modules[module_name],
160            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
161        )
162    ]
163
164
165def apply_index_offset(
166    this: exp.Expression,
167    expressions: t.List[E],
168    offset: int,
169    dialect: DialectType = None,
170) -> t.List[E]:
171    """
172    Applies an offset to a given integer literal expression.
173
174    Args:
175        this: The target of the index.
176        expressions: The expression the offset will be applied to, wrapped in a list.
177        offset: The offset that will be applied.
178        dialect: The dialect of interest.
179
180    Returns:
181        The original expression with the offset applied to it, wrapped in a list. If the provided
182        `expressions` argument contains more than one expression, it's returned unaffected.
183    """
184    if not offset or len(expressions) != 1:
185        return expressions
186
187    expression = expressions[0]
188
189    from sqlglot import exp
190    from sqlglot.optimizer.annotate_types import annotate_types
191    from sqlglot.optimizer.simplify import simplify
192
193    if not this.type:
194        annotate_types(this, dialect=dialect)
195
196    if t.cast(exp.DataType, this.type).this not in (
197        exp.DataType.Type.UNKNOWN,
198        exp.DataType.Type.ARRAY,
199    ):
200        return expressions
201
202    if not expression.type:
203        annotate_types(expression, dialect=dialect)
204
205    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
206        logger.info("Applying array index offset (%s)", offset)
207        expression = simplify(expression + offset)
208        return [expression]
209
210    return expressions
211
212
213def camel_to_snake_case(name: str) -> str:
214    """Converts `name` from camelCase to uppercase snake_case and returns the result."""
215    return CAMEL_CASE_PATTERN.sub("_", name).upper()
216
217
218def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
219    """
220    Applies a transformation to a given expression until a fixed point is reached.
221
222    Args:
223        expression: The expression to be transformed.
224        func: The transformation to be applied.
225
226    Returns:
227        The transformed expression.
228    """
229
230    while True:
231        start_hash = hash(expression)
232        expression = func(expression)
233        end_hash = hash(expression)
234
235        if start_hash == end_hash:
236            break
237
238    return expression
239
240
241def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
242    """
243    Sorts a given directed acyclic graph in topological order.
244
245    Args:
246        dag: The graph to be sorted.
247
248    Returns:
249        A list that contains all of the graph's nodes in topological order.
250    """
251    result = []
252
253    for node, deps in tuple(dag.items()):
254        for dep in deps:
255            if dep not in dag:
256                dag[dep] = set()
257
258    while dag:
259        current = {node for node, deps in dag.items() if not deps}
260
261        if not current:
262            raise ValueError("Cycle error")
263
264        for node in current:
265            dag.pop(node)
266
267        for deps in dag.values():
268            deps -= current
269
270        result.extend(sorted(current))  # type: ignore
271
272    return result
273
274
275def open_file(file_name: str) -> t.TextIO:
276    """Open a file that may be compressed as gzip and return it in universal newline mode."""
277    with open(file_name, "rb") as f:
278        gzipped = f.read(2) == b"\x1f\x8b"
279
280    if gzipped:
281        import gzip
282
283        return gzip.open(file_name, "rt", newline="")
284
285    return open(file_name, encoding="utf-8", newline="")
286
287
288@contextmanager
289def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
290    """
291    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
292
293    Args:
294        read_csv: A `ReadCSV` function call.
295
296    Yields:
297        A Python csv reader.
298    """
299    args = read_csv.expressions
300    file = open_file(read_csv.name)
301
302    delimiter = ","
303    args = iter(arg.name for arg in args)  # type: ignore
304    for k, v in zip(args, args):
305        if k == "delimiter":
306            delimiter = v
307
308    try:
309        import csv as csv_
310
311        yield csv_.reader(file, delimiter=delimiter)
312    finally:
313        file.close()
314
315
316def find_new_name(taken: t.Collection[str], base: str) -> str:
317    """
318    Searches for a new name.
319
320    Args:
321        taken: A collection of taken names.
322        base: Base name to alter.
323
324    Returns:
325        The new, available name.
326    """
327    if base not in taken:
328        return base
329
330    i = 2
331    new = f"{base}_{i}"
332    while new in taken:
333        i += 1
334        new = f"{base}_{i}"
335
336    return new
337
338
339def is_int(text: str) -> bool:
340    return is_type(text, int)
341
342
343def is_float(text: str) -> bool:
344    return is_type(text, float)
345
346
347def is_type(text: str, target_type: t.Type) -> bool:
348    try:
349        target_type(text)
350        return True
351    except ValueError:
352        return False
353
354
355def name_sequence(prefix: str) -> t.Callable[[], str]:
356    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
357    sequence = count()
358    return lambda: f"{prefix}{next(sequence)}"
359
360
361def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
362    """Returns a dictionary created from an object's attributes."""
363    return {
364        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
365        **kwargs,
366    }
367
368
369def split_num_words(
370    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
371) -> t.List[t.Optional[str]]:
372    """
373    Splits a value and returns at least `min_num_words` words, using `None` for words that don't exist.
374
375    Args:
376        value: The value to be split.
377        sep: The value to use to split on.
378        min_num_words: The minimum number of words that are going to be in the result.
379        fill_from_start: Indicates whether `None` values should be inserted at the start or end of the list.
380
381    Examples:
382        >>> split_num_words("db.table", ".", 3)
383        [None, 'db', 'table']
384        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
385        ['db', 'table', None]
386        >>> split_num_words("db.table", ".", 1)
387        ['db', 'table']
388
389    Returns:
390        The list of words returned by `split`, possibly augmented by a number of `None` values.
391    """
392    words = value.split(sep)
393    if fill_from_start:
394        return [None] * (min_num_words - len(words)) + words
395    return words + [None] * (min_num_words - len(words))
396
397
398def is_iterable(value: t.Any) -> bool:
399    """
400    Checks if the value is an iterable, excluding the types `str` and `bytes`.
401
402    Examples:
403        >>> is_iterable([1,2])
404        True
405        >>> is_iterable("test")
406        False
407
408    Args:
409        value: The value to check if it is an iterable.
410
411    Returns:
412        A `bool` value indicating if it is an iterable.
413    """
414    from sqlglot import Expression
415
416    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
417
418
419def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
420    """
421    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
422    type `str` and `bytes` are not regarded as iterables.
423
424    Examples:
425        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
426        [1, 2, 3, 4, 5, 'bla']
427        >>> list(flatten([1, 2, 3]))
428        [1, 2, 3]
429
430    Args:
431        values: The value to be flattened.
432
433    Yields:
434        Non-iterable elements in `values`.
435    """
436    for value in values:
437        if is_iterable(value):
438            yield from flatten(value)
439        else:
440            yield value
441
442
443def dict_depth(d: t.Dict) -> int:
444    """
445    Get the nesting depth of a dictionary.
446
447    Example:
448        >>> dict_depth(None)
449        0
450        >>> dict_depth({})
451        1
452        >>> dict_depth({"a": "b"})
453        1
454        >>> dict_depth({"a": {}})
455        2
456        >>> dict_depth({"a": {"b": {}}})
457        3
458    """
459    try:
460        return 1 + dict_depth(next(iter(d.values())))
461    except AttributeError:
462        # d doesn't have attribute "values"
463        return 0
464    except StopIteration:
465        # d.values() returns an empty sequence
466        return 1
467
468
469def first(it: t.Iterable[T]) -> T:
470    """Returns the first element from an iterable (useful for sets)."""
471    return next(i for i in it)
472
473
474def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
475    if isinstance(value, bool) or value is None:
476        return value
477
478    # Coerce the value to boolean if it matches one of the truthy/falsy values below
479    value_lower = value.lower()
480    if value_lower in ("true", "1"):
481        return True
482    if value_lower in ("false", "0"):
483        return False
484
485    return value
486
487
488def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
489    """
490    Merges a sequence of ranges, represented as tuples (low, high) whose values
491    belong to some totally-ordered set.
492
493    Example:
494        >>> merge_ranges([(1, 3), (2, 6)])
495        [(1, 6)]
496    """
497    if not ranges:
498        return []
499
500    ranges = sorted(ranges)
501
502    merged = [ranges[0]]
503
504    for start, end in ranges[1:]:
505        last_start, last_end = merged[-1]
506
507        if start <= last_end:
508            merged[-1] = (last_start, max(last_end, end))
509        else:
510            merged.append((start, end))
511
512    return merged
513
514
515def is_iso_date(text: str) -> bool:
516    try:
517        datetime.date.fromisoformat(text)
518        return True
519    except ValueError:
520        return False
521
522
523def is_iso_datetime(text: str) -> bool:
524    try:
525        datetime.datetime.fromisoformat(text)
526        return True
527    except ValueError:
528        return False
529
530
531# Interval units that operate on date components
532DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
533
534
535def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
536    return expression is not None and expression.name.lower() in DATE_UNITS
537
538
539K = t.TypeVar("K")
540V = t.TypeVar("V")
541
542
543class SingleValuedMapping(t.Mapping[K, V]):
544    """
545    Mapping where all keys return the same value.
546
547    This rigamarole is meant to avoid copying keys, which was originally intended
548    as an optimization while qualifying columns for tables with lots of columns.
549    """
550
551    def __init__(self, keys: t.Collection[K], value: V):
552        self._keys = keys if isinstance(keys, Set) else set(keys)
553        self._value = value
554
555    def __getitem__(self, key: K) -> V:
556        if key in self._keys:
557            return self._value
558        raise KeyError(key)
559
560    def __len__(self) -> int:
561        return len(self._keys)
562
563    def __iter__(self) -> t.Iterator[K]:
564        return iter(self._keys)
CAMEL_CASE_PATTERN = re.compile('(?<!^)(?=[A-Z])')
PYTHON_VERSION = (3, 10)
logger = <Logger sqlglot (WARNING)>
class AutoName(enum.Enum):
29class AutoName(Enum):
30    """
31    This is used for creating Enum classes where `auto()` is the string form
32    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
33
34    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
35    """
36
37    def _generate_next_value_(name, _start, _count, _last_values):
38        return name

This is used for creating Enum classes where auto() is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
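
For example, a minimal sketch (the enum below is hypothetical, purely for illustration):

>>> from enum import auto
>>> class TokenKind(AutoName):
...     SELECT = auto()
>>> TokenKind.SELECT.value
'SELECT'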

class classproperty(builtins.property):
41class classproperty(property):
42    """
43    Similar to a normal property but works for class methods
44    """
45
46    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
47        return classmethod(self.fget).__get__(None, owner)()  # type: ignore

Similar to a normal property but works for class methods
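
For example, a minimal sketch (the class below is hypothetical, for illustration only):

>>> class Config:
...     @classproperty
...     def version(cls) -> str:
...         return "1.0"
>>> Config.version
'1.0'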

def suggest_closest_match_and_fail(kind: str, word: str, possibilities: Iterable[str]) -> None:
50def suggest_closest_match_and_fail(
51    kind: str,
52    word: str,
53    possibilities: t.Iterable[str],
54) -> None:
55    close_matches = get_close_matches(word, possibilities, n=1)
56
57    similar = seq_get(close_matches, 0) or ""
58    if similar:
59        similar = f" Did you mean {similar}?"
60
61    raise ValueError(f"Unknown {kind} '{word}'.{similar}")
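
Raises a ValueError for an unknown word, suggesting the closest match from possibilities when one exists. For example:

>>> suggest_closest_match_and_fail("dialect", "postgress", ["postgres", "mysql"])
Traceback (most recent call last):
  ...
ValueError: Unknown dialect 'postgress'. Did you mean postgres?
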
def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:
64def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
65    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
66    try:
67        return seq[index]
68    except IndexError:
69        return None

Returns the value in seq at position index, or None if index is out of bounds.
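
For example:

>>> seq_get([1, 2, 3], 1)
2
>>> seq_get([1, 2, 3], 5) is None
True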

def ensure_list(value):
84def ensure_list(value):
85    """
86    Ensures that a value is a list, otherwise casts or wraps it into one.
87
88    Args:
89        value: The value of interest.
90
91    Returns:
92        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
93    """
94    if value is None:
95        return []
96    if isinstance(value, (list, tuple)):
97        return list(value)
98
99    return [value]

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: The value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
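
For example:

>>> ensure_list(None)
[]
>>> ensure_list((1, 2))
[1, 2]
>>> ensure_list("x")
['x']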

def ensure_collection(value):
110def ensure_collection(value):
111    """
112    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
113
114    Args:
115        value: The value of interest.
116
117    Returns:
118        The value if it's a collection, or else the value wrapped in a list.
119    """
120    if value is None:
121        return []
122    return (
123        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
124    )

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: The value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.
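
For example:

>>> ensure_collection({1, 2})
{1, 2}
>>> ensure_collection("x")
['x']
>>> ensure_collection(None)
[]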

def csv(*args: str, sep: str = ', ') -> str:
127def csv(*args: str, sep: str = ", ") -> str:
128    """
129    Formats any number of string arguments as CSV.
130
131    Args:
132        args: The string arguments to format.
133        sep: The argument separator.
134
135    Returns:
136        The arguments formatted as a CSV string.
137    """
138    return sep.join(arg for arg in args if arg)

Formats any number of string arguments as CSV.

Arguments:
  • args: The string arguments to format.
  • sep: The argument separator.
Returns:

The arguments formatted as a CSV string.
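
For example:

>>> csv("a", "", "b")
'a, b'
>>> csv("a", "b", sep="|")
'a|b'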

def subclasses( module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:
141def subclasses(
142    module_name: str,
143    classes: t.Type | t.Tuple[t.Type, ...],
144    exclude: t.Type | t.Tuple[t.Type, ...] = (),
145) -> t.List[t.Type]:
146    """
147    Returns all subclasses for a collection of classes, possibly excluding some of them.
148
149    Args:
150        module_name: The name of the module to search for subclasses in.
151        classes: Class(es) we want to find the subclasses of.
152        exclude: Class(es) we want to exclude from the returned list.
153
154    Returns:
155        The target subclasses.
156    """
157    return [
158        obj
159        for _, obj in inspect.getmembers(
160            sys.modules[module_name],
161            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
162        )
163    ]

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: The name of the module to search for subclasses in.
  • classes: Class(es) we want to find the subclasses of.
  • exclude: Class(es) we want to exclude from the returned list.
Returns:

The target subclasses.
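
A rough usage sketch, collecting the function expressions defined in sqlglot.expressions (the exact classes returned depend on the installed version):

>>> from sqlglot import exp
>>> funcs = subclasses(exp.__name__, exp.Func, exclude=(exp.Func,))
>>> all(issubclass(f, exp.Func) and f is not exp.Func for f in funcs)
True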

def apply_index_offset( this: sqlglot.expressions.Expression, expressions: List[~E], offset: int, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> List[~E]:
166def apply_index_offset(
167    this: exp.Expression,
168    expressions: t.List[E],
169    offset: int,
170    dialect: DialectType = None,
171) -> t.List[E]:
172    """
173    Applies an offset to a given integer literal expression.
174
175    Args:
176        this: The target of the index.
177        expressions: The expression the offset will be applied to, wrapped in a list.
178        offset: The offset that will be applied.
179        dialect: The dialect of interest.
180
181    Returns:
182        The original expression with the offset applied to it, wrapped in a list. If the provided
183        `expressions` argument contains more than one expression, it's returned unaffected.
184    """
185    if not offset or len(expressions) != 1:
186        return expressions
187
188    expression = expressions[0]
189
190    from sqlglot import exp
191    from sqlglot.optimizer.annotate_types import annotate_types
192    from sqlglot.optimizer.simplify import simplify
193
194    if not this.type:
195        annotate_types(this, dialect=dialect)
196
197    if t.cast(exp.DataType, this.type).this not in (
198        exp.DataType.Type.UNKNOWN,
199        exp.DataType.Type.ARRAY,
200    ):
201        return expressions
202
203    if not expression.type:
204        annotate_types(expression, dialect=dialect)
205
206    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
207        logger.info("Applying array index offset (%s)", offset)
208        expression = simplify(expression + offset)
209        return [expression]
210
211    return expressions

Applies an offset to a given integer literal expression.

Arguments:
  • this: The target of the index.
  • expressions: The expression the offset will be applied to, wrapped in a list.
  • offset: The offset that will be applied.
  • dialect: The dialect of interest.
Returns:

The original expression with the offset applied to it, wrapped in a list. If the provided expressions argument contains more than one expression, it's returned unaffected.
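
A rough usage sketch for a dialect with 1-based array indexing (the column name is arbitrary; non-literal or multi-expression inputs are returned unchanged):

>>> from sqlglot import exp
>>> arr = exp.column("arr")                                   # the target being indexed
>>> shifted = apply_index_offset(arr, [exp.Literal.number(3)], -1)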

def camel_to_snake_case(name: str) -> str:
214def camel_to_snake_case(name: str) -> str:
215    """Converts `name` from camelCase to uppercase snake_case and returns the result."""
216    return CAMEL_CASE_PATTERN.sub("_", name).upper()

Converts name from camelCase to uppercase snake_case and returns the result.
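
For example:

>>> camel_to_snake_case("fooBarBaz")
'FOO_BAR_BAZ'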

def while_changing( expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:
219def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
220    """
221    Applies a transformation to a given expression until a fixed point is reached.
222
223    Args:
224        expression: The expression to be transformed.
225        func: The transformation to be applied.
226
227    Returns:
228        The transformed expression.
229    """
230
231    while True:
232        start_hash = hash(expression)
233        expression = func(expression)
234        end_hash = hash(expression)
235
236        if start_hash == end_hash:
237            break
238
239    return expression

Applies a transformation to a given expression until a fixed point is reached.

Arguments:
  • expression: The expression to be transformed.
  • func: The transformation to be applied.
Returns:

The transformed expression.
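
A rough usage sketch using the optimizer's simplify rule as the transformation:

>>> import sqlglot
>>> from sqlglot.optimizer.simplify import simplify
>>> while_changing(sqlglot.parse_one("TRUE AND TRUE AND TRUE"), simplify).sql()
'TRUE'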

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:
242def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
243    """
244    Sorts a given directed acyclic graph in topological order.
245
246    Args:
247        dag: The graph to be sorted.
248
249    Returns:
250        A list that contains all of the graph's nodes in topological order.
251    """
252    result = []
253
254    for node, deps in tuple(dag.items()):
255        for dep in deps:
256            if dep not in dag:
257                dag[dep] = set()
258
259    while dag:
260        current = {node for node, deps in dag.items() if not deps}
261
262        if not current:
263            raise ValueError("Cycle error")
264
265        for node in current:
266            dag.pop(node)
267
268        for deps in dag.values():
269            deps -= current
270
271        result.extend(sorted(current))  # type: ignore
272
273    return result

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: The graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.
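
For example, where each node maps to the set of nodes it depends on:

>>> tsort({"b": {"a"}, "c": {"b"}, "a": set()})
['a', 'b', 'c']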

def open_file(file_name: str) -> TextIO:
276def open_file(file_name: str) -> t.TextIO:
277    """Open a file that may be compressed as gzip and return it in universal newline mode."""
278    with open(file_name, "rb") as f:
279        gzipped = f.read(2) == b"\x1f\x8b"
280
281    if gzipped:
282        import gzip
283
284        return gzip.open(file_name, "rt", newline="")
285
286    return open(file_name, encoding="utf-8", newline="")

Open a file that may be compressed as gzip and return it in universal newline mode.

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:
289@contextmanager
290def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
291    """
292    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
293
294    Args:
295        read_csv: A `ReadCSV` function call.
296
297    Yields:
298        A Python csv reader.
299    """
300    args = read_csv.expressions
301    file = open_file(read_csv.name)
302
303    delimiter = ","
304    args = iter(arg.name for arg in args)  # type: ignore
305    for k, v in zip(args, args):
306        if k == "delimiter":
307            delimiter = v
308
309    try:
310        import csv as csv_
311
312        yield csv_.reader(file, delimiter=delimiter)
313    finally:
314        file.close()

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: A ReadCSV function call.
Yields:

A Python csv reader.

def find_new_name(taken: Collection[str], base: str) -> str:
317def find_new_name(taken: t.Collection[str], base: str) -> str:
318    """
319    Searches for a new name.
320
321    Args:
322        taken: A collection of taken names.
323        base: Base name to alter.
324
325    Returns:
326        The new, available name.
327    """
328    if base not in taken:
329        return base
330
331    i = 2
332    new = f"{base}_{i}"
333    while new in taken:
334        i += 1
335        new = f"{base}_{i}"
336
337    return new

Searches for a new name.

Arguments:
  • taken: A collection of taken names.
  • base: Base name to alter.
Returns:

The new, available name.
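
For example:

>>> find_new_name({"a", "a_2"}, "a")
'a_3'
>>> find_new_name({"a"}, "b")
'b'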

def is_int(text: str) -> bool:
340def is_int(text: str) -> bool:
341    return is_type(text, int)
def is_float(text: str) -> bool:
344def is_float(text: str) -> bool:
345    return is_type(text, float)
def is_type(text: str, target_type: Type) -> bool:
348def is_type(text: str, target_type: t.Type) -> bool:
349    try:
350        target_type(text)
351        return True
352    except ValueError:
353        return False
def name_sequence(prefix: str) -> Callable[[], str]:
356def name_sequence(prefix: str) -> t.Callable[[], str]:
357    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
358    sequence = count()
359    return lambda: f"{prefix}{next(sequence)}"

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").
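
For example:

>>> gen = name_sequence("a")
>>> gen(), gen(), gen()
('a0', 'a1', 'a2')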

def object_to_dict(obj: Any, **kwargs) -> Dict:
362def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
363    """Returns a dictionary created from an object's attributes."""
364    return {
365        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
366        **kwargs,
367    }

Returns a dictionary created from an object's attributes.

def split_num_words( value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:
370def split_num_words(
371    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
372) -> t.List[t.Optional[str]]:
373    """
374    Splits a value and returns at least `min_num_words` words, using `None` for words that don't exist.
375
376    Args:
377        value: The value to be split.
378        sep: The value to use to split on.
379        min_num_words: The minimum number of words that are going to be in the result.
380        fill_from_start: Indicates whether `None` values should be inserted at the start or end of the list.
381
382    Examples:
383        >>> split_num_words("db.table", ".", 3)
384        [None, 'db', 'table']
385        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
386        ['db', 'table', None]
387        >>> split_num_words("db.table", ".", 1)
388        ['db', 'table']
389
390    Returns:
391        The list of words returned by `split`, possibly augmented by a number of `None` values.
392    """
393    words = value.split(sep)
394    if fill_from_start:
395        return [None] * (min_num_words - len(words)) + words
396    return words + [None] * (min_num_words - len(words))

Splits a value and returns at least min_num_words words, using None for words that don't exist.

Arguments:
  • value: The value to be split.
  • sep: The value to use to split on.
  • min_num_words: The minimum number of words that are going to be in the result.
  • fill_from_start: Indicates whether None values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:
399def is_iterable(value: t.Any) -> bool:
400    """
401    Checks if the value is an iterable, excluding the types `str` and `bytes`.
402
403    Examples:
404        >>> is_iterable([1,2])
405        True
406        >>> is_iterable("test")
407        False
408
409    Args:
410        value: The value to check if it is an iterable.
411
412    Returns:
413        A `bool` value indicating if it is an iterable.
414    """
415    from sqlglot import Expression
416
417    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))

Checks if the value is an iterable, excluding the types str and bytes.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: The value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:
420def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
421    """
422    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
423    type `str` and `bytes` are not regarded as iterables.
424
425    Examples:
426        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
427        [1, 2, 3, 4, 5, 'bla']
428        >>> list(flatten([1, 2, 3]))
429        [1, 2, 3]
430
431    Args:
432        values: The value to be flattened.
433
434    Yields:
435        Non-iterable elements in `values`.
436    """
437    for value in values:
438        if is_iterable(value):
439            yield from flatten(value)
440        else:
441            yield value

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: The value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:
444def dict_depth(d: t.Dict) -> int:
445    """
446    Get the nesting depth of a dictionary.
447
448    Example:
449        >>> dict_depth(None)
450        0
451        >>> dict_depth({})
452        1
453        >>> dict_depth({"a": "b"})
454        1
455        >>> dict_depth({"a": {}})
456        2
457        >>> dict_depth({"a": {"b": {}}})
458        3
459    """
460    try:
461        return 1 + dict_depth(next(iter(d.values())))
462    except AttributeError:
463        # d doesn't have attribute "values"
464        return 0
465    except StopIteration:
466        # d.values() returns an empty sequence
467        return 1

Get the nesting depth of a dictionary.

Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: Iterable[~T]) -> ~T:
470def first(it: t.Iterable[T]) -> T:
471    """Returns the first element from an iterable (useful for sets)."""
472    return next(i for i in it)

Returns the first element from an iterable (useful for sets).

def to_bool(value: Union[str, bool, NoneType]) -> Union[str, bool, NoneType]:
475def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
476    if isinstance(value, bool) or value is None:
477        return value
478
479    # Coerce the value to boolean if it matches one of the truthy/falsy values below
480    value_lower = value.lower()
481    if value_lower in ("true", "1"):
482        return True
483    if value_lower in ("false", "0"):
484        return False
485
486    return value
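
Coerces the strings "true"/"1" and "false"/"0" (case-insensitively) to booleans; any other value is returned unchanged. For example:

>>> to_bool("TRUE")
True
>>> to_bool("0")
False
>>> to_bool("yes")
'yes'
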
def merge_ranges(ranges: List[Tuple[~A, ~A]]) -> List[Tuple[~A, ~A]]:
489def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
490    """
491    Merges a sequence of ranges, represented as tuples (low, high) whose values
492    belong to some totally-ordered set.
493
494    Example:
495        >>> merge_ranges([(1, 3), (2, 6)])
496        [(1, 6)]
497    """
498    if not ranges:
499        return []
500
501    ranges = sorted(ranges)
502
503    merged = [ranges[0]]
504
505    for start, end in ranges[1:]:
506        last_start, last_end = merged[-1]
507
508        if start <= last_end:
509            merged[-1] = (last_start, max(last_end, end))
510        else:
511            merged.append((start, end))
512
513    return merged

Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.

Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]
def is_iso_date(text: str) -> bool:
516def is_iso_date(text: str) -> bool:
517    try:
518        datetime.date.fromisoformat(text)
519        return True
520    except ValueError:
521        return False
def is_iso_datetime(text: str) -> bool:
524def is_iso_datetime(text: str) -> bool:
525    try:
526        datetime.datetime.fromisoformat(text)
527        return True
528    except ValueError:
529        return False
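
For example:

>>> is_iso_date("2023-01-15")
True
>>> is_iso_datetime("2023-01-15 12:30:00")
True
>>> is_iso_date("15/01/2023")
False
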
DATE_UNITS = {'year', 'month', 'quarter', 'week', 'day', 'year_month'}
def is_date_unit(expression: Optional[sqlglot.expressions.Expression]) -> bool:
536def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
537    return expression is not None and expression.name.lower() in DATE_UNITS
class SingleValuedMapping(typing.Mapping[~K, ~V]):
544class SingleValuedMapping(t.Mapping[K, V]):
545    """
546    Mapping where all keys return the same value.
547
548    This rigamarole is meant to avoid copying keys, which was originally intended
549    as an optimization while qualifying columns for tables with lots of columns.
550    """
551
552    def __init__(self, keys: t.Collection[K], value: V):
553        self._keys = keys if isinstance(keys, Set) else set(keys)
554        self._value = value
555
556    def __getitem__(self, key: K) -> V:
557        if key in self._keys:
558            return self._value
559        raise KeyError(key)
560
561    def __len__(self) -> int:
562        return len(self._keys)
563
564    def __iter__(self) -> t.Iterator[K]:
565        return iter(self._keys)

Mapping where all keys return the same value.

This rigamarole is meant to avoid copying keys, which was originally intended as an optimization while qualifying columns for tables with lots of columns.

SingleValuedMapping(keys: Collection[~K], value: ~V)
552    def __init__(self, keys: t.Collection[K], value: V):
553        self._keys = keys if isinstance(keys, Set) else set(keys)
554        self._value = value
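
For example:

>>> columns = SingleValuedMapping(["a", "b", "c"], "t")
>>> columns["b"]
't'
>>> len(columns)
3
>>> sorted(columns)
['a', 'b', 'c']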