
sqlglot.helper

  1from __future__ import annotations
  2
  3import datetime
  4import inspect
  5import logging
  6import re
  7import sys
  8import typing as t
  9from collections.abc import Collection, Set
 10from contextlib import contextmanager
 11from copy import copy
 12from enum import Enum
 13from itertools import count
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot import exp
 17    from sqlglot._typing import A, E, T
 18    from sqlglot.expressions import Expression
 19
 20
 21CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
 22PYTHON_VERSION = sys.version_info[:2]
 23logger = logging.getLogger("sqlglot")
 24
 25
 26class AutoName(Enum):
 27    """
 28    This is used for creating Enum classes where `auto()` is the string form
 29    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
 30
 31    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
 32    """
 33
 34    def _generate_next_value_(name, _start, _count, _last_values):
 35        return name
 36
 37
 38class classproperty(property):
 39    """
 40    Similar to a normal property but works for class methods
 41    """
 42
 43    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
 44        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
 45
 46
 47def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 48    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 49    try:
 50        return seq[index]
 51    except IndexError:
 52        return None
 53
 54
 55@t.overload
 56def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
 57
 58
 59@t.overload
 60def ensure_list(value: None) -> t.List: ...
 61
 62
 63@t.overload
 64def ensure_list(value: T) -> t.List[T]: ...
 65
 66
 67def ensure_list(value):
 68    """
 69    Ensures that a value is a list, otherwise casts or wraps it into one.
 70
 71    Args:
 72        value: The value of interest.
 73
 74    Returns:
 75        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 76    """
 77    if value is None:
 78        return []
 79    if isinstance(value, (list, tuple)):
 80        return list(value)
 81
 82    return [value]
 83
 84
 85@t.overload
 86def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
 87
 88
 89@t.overload
 90def ensure_collection(value: T) -> t.Collection[T]: ...
 91
 92
 93def ensure_collection(value):
 94    """
 95    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 96
 97    Args:
 98        value: The value of interest.
 99
100    Returns:
101        The value if it's a collection, or else the value wrapped in a list.
102    """
103    if value is None:
104        return []
105    return (
106        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
107    )
108
109
110def csv(*args: str, sep: str = ", ") -> str:
111    """
112    Formats any number of string arguments as CSV.
113
114    Args:
115        args: The string arguments to format.
116        sep: The argument separator.
117
118    Returns:
119        The arguments formatted as a CSV string.
120    """
121    return sep.join(arg for arg in args if arg)
122
123
124def subclasses(
125    module_name: str,
126    classes: t.Type | t.Tuple[t.Type, ...],
127    exclude: t.Type | t.Tuple[t.Type, ...] = (),
128) -> t.List[t.Type]:
129    """
130    Returns all subclasses for a collection of classes, possibly excluding some of them.
131
132    Args:
133        module_name: The name of the module to search for subclasses in.
134        classes: Class(es) we want to find the subclasses of.
135        exclude: Class(es) we want to exclude from the returned list.
136
137    Returns:
138        The target subclasses.
139    """
140    return [
141        obj
142        for _, obj in inspect.getmembers(
143            sys.modules[module_name],
144            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
145        )
146    ]
147
148
149def apply_index_offset(
150    this: exp.Expression,
151    expressions: t.List[E],
152    offset: int,
153) -> t.List[E]:
154    """
155    Applies an offset to a given integer literal expression.
156
157    Args:
158        this: The target of the index.
159        expressions: The expression the offset will be applied to, wrapped in a list.
160        offset: The offset that will be applied.
161
162    Returns:
163        The original expression with the offset applied to it, wrapped in a list. If the provided
164        `expressions` argument contains more than one expression, it's returned unaffected.
165    """
166    if not offset or len(expressions) != 1:
167        return expressions
168
169    expression = expressions[0]
170
171    from sqlglot import exp
172    from sqlglot.optimizer.annotate_types import annotate_types
173    from sqlglot.optimizer.simplify import simplify
174
175    if not this.type:
176        annotate_types(this)
177
178    if t.cast(exp.DataType, this.type).this not in (
179        exp.DataType.Type.UNKNOWN,
180        exp.DataType.Type.ARRAY,
181    ):
182        return expressions
183
184    if not expression.type:
185        annotate_types(expression)
186
187    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
188        logger.info("Applying array index offset (%s)", offset)
189        expression = simplify(expression + offset)
190        return [expression]
191
192    return expressions
193
194
195def camel_to_snake_case(name: str) -> str:
196    """Converts `name` from camelCase to snake_case and returns the result."""
197    return CAMEL_CASE_PATTERN.sub("_", name).upper()
198
199
200def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
201    """
202    Applies a transformation to a given expression until a fix point is reached.
203
204    Args:
205        expression: The expression to be transformed.
206        func: The transformation to be applied.
207
208    Returns:
209        The transformed expression.
210    """
211    while True:
212        for n in reversed(tuple(expression.walk())):
213            n._hash = hash(n)
214
215        start = hash(expression)
216        expression = func(expression)
217
218        for n in expression.walk():
219            n._hash = None
220        if start == hash(expression):
221            break
222
223    return expression
224
225
226def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
227    """
228    Sorts a given directed acyclic graph in topological order.
229
230    Args:
231        dag: The graph to be sorted.
232
233    Returns:
234        A list that contains all of the graph's nodes in topological order.
235    """
236    result = []
237
238    for node, deps in tuple(dag.items()):
239        for dep in deps:
240            if dep not in dag:
241                dag[dep] = set()
242
243    while dag:
244        current = {node for node, deps in dag.items() if not deps}
245
246        if not current:
247            raise ValueError("Cycle error")
248
249        for node in current:
250            dag.pop(node)
251
252        for deps in dag.values():
253            deps -= current
254
255        result.extend(sorted(current))  # type: ignore
256
257    return result
258
259
260def open_file(file_name: str) -> t.TextIO:
261    """Open a file that may be compressed as gzip and return it in universal newline mode."""
262    with open(file_name, "rb") as f:
263        gzipped = f.read(2) == b"\x1f\x8b"
264
265    if gzipped:
266        import gzip
267
268        return gzip.open(file_name, "rt", newline="")
269
270    return open(file_name, encoding="utf-8", newline="")
271
272
273@contextmanager
274def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
275    """
276    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
277
278    Args:
279        read_csv: A `ReadCSV` function call.
280
281    Yields:
282        A python csv reader.
283    """
284    args = read_csv.expressions
285    file = open_file(read_csv.name)
286
287    delimiter = ","
288    args = iter(arg.name for arg in args)  # type: ignore
289    for k, v in zip(args, args):
290        if k == "delimiter":
291            delimiter = v
292
293    try:
294        import csv as csv_
295
296        yield csv_.reader(file, delimiter=delimiter)
297    finally:
298        file.close()
299
300
301def find_new_name(taken: t.Collection[str], base: str) -> str:
302    """
303    Searches for a new name.
304
305    Args:
306        taken: A collection of taken names.
307        base: Base name to alter.
308
309    Returns:
310        The new, available name.
311    """
312    if base not in taken:
313        return base
314
315    i = 2
316    new = f"{base}_{i}"
317    while new in taken:
318        i += 1
319        new = f"{base}_{i}"
320
321    return new
322
323
324def is_int(text: str) -> bool:
325    return is_type(text, int)
326
327
328def is_float(text: str) -> bool:
329    return is_type(text, float)
330
331
332def is_type(text: str, target_type: t.Type) -> bool:
333    try:
334        target_type(text)
335        return True
336    except ValueError:
337        return False
338
339
340def name_sequence(prefix: str) -> t.Callable[[], str]:
341    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
342    sequence = count()
343    return lambda: f"{prefix}{next(sequence)}"
344
345
346def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
347    """Returns a dictionary created from an object's attributes."""
348    return {
349        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
350        **kwargs,
351    }
352
353
354def split_num_words(
355    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
356) -> t.List[t.Optional[str]]:
357    """
358    Perform a split on a value and return N words as a result with `None` used for words that don't exist.
359
360    Args:
361        value: The value to be split.
362        sep: The value to use to split on.
363        min_num_words: The minimum number of words that are going to be in the result.
 364        fill_from_start: Indicates whether `None` values should be inserted at the start or end of the list.
365
366    Examples:
367        >>> split_num_words("db.table", ".", 3)
368        [None, 'db', 'table']
369        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
370        ['db', 'table', None]
371        >>> split_num_words("db.table", ".", 1)
372        ['db', 'table']
373
374    Returns:
375        The list of words returned by `split`, possibly augmented by a number of `None` values.
376    """
377    words = value.split(sep)
378    if fill_from_start:
379        return [None] * (min_num_words - len(words)) + words
380    return words + [None] * (min_num_words - len(words))
381
382
383def is_iterable(value: t.Any) -> bool:
384    """
385    Checks if the value is an iterable, excluding the types `str` and `bytes`.
386
387    Examples:
388        >>> is_iterable([1,2])
389        True
390        >>> is_iterable("test")
391        False
392
393    Args:
394        value: The value to check if it is an iterable.
395
396    Returns:
397        A `bool` value indicating if it is an iterable.
398    """
399    from sqlglot import Expression
400
401    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
402
403
404def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
405    """
406    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
407    type `str` and `bytes` are not regarded as iterables.
408
409    Examples:
410        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
411        [1, 2, 3, 4, 5, 'bla']
412        >>> list(flatten([1, 2, 3]))
413        [1, 2, 3]
414
415    Args:
416        values: The value to be flattened.
417
418    Yields:
419        Non-iterable elements in `values`.
420    """
421    for value in values:
422        if is_iterable(value):
423            yield from flatten(value)
424        else:
425            yield value
426
427
428def dict_depth(d: t.Dict) -> int:
429    """
430    Get the nesting depth of a dictionary.
431
432    Example:
433        >>> dict_depth(None)
434        0
435        >>> dict_depth({})
436        1
437        >>> dict_depth({"a": "b"})
438        1
439        >>> dict_depth({"a": {}})
440        2
441        >>> dict_depth({"a": {"b": {}}})
442        3
443    """
444    try:
445        return 1 + dict_depth(next(iter(d.values())))
446    except AttributeError:
447        # d doesn't have attribute "values"
448        return 0
449    except StopIteration:
450        # d.values() returns an empty sequence
451        return 1
452
453
454def first(it: t.Iterable[T]) -> T:
455    """Returns the first element from an iterable (useful for sets)."""
456    return next(i for i in it)
457
458
459def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
460    if isinstance(value, bool) or value is None:
461        return value
462
463    # Coerce the value to boolean if it matches to the truthy/falsy values below
464    value_lower = value.lower()
465    if value_lower in ("true", "1"):
466        return True
467    if value_lower in ("false", "0"):
468        return False
469
470    return value
471
472
473def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
474    """
475    Merges a sequence of ranges, represented as tuples (low, high) whose values
476    belong to some totally-ordered set.
477
478    Example:
479        >>> merge_ranges([(1, 3), (2, 6)])
480        [(1, 6)]
481    """
482    if not ranges:
483        return []
484
485    ranges = sorted(ranges)
486
487    merged = [ranges[0]]
488
489    for start, end in ranges[1:]:
490        last_start, last_end = merged[-1]
491
492        if start <= last_end:
493            merged[-1] = (last_start, max(last_end, end))
494        else:
495            merged.append((start, end))
496
497    return merged
498
499
500def is_iso_date(text: str) -> bool:
501    try:
502        datetime.date.fromisoformat(text)
503        return True
504    except ValueError:
505        return False
506
507
508def is_iso_datetime(text: str) -> bool:
509    try:
510        datetime.datetime.fromisoformat(text)
511        return True
512    except ValueError:
513        return False
514
515
516# Interval units that operate on date components
517DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
518
519
520def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
521    return expression is not None and expression.name.lower() in DATE_UNITS
522
523
524K = t.TypeVar("K")
525V = t.TypeVar("V")
526
527
528class SingleValuedMapping(t.Mapping[K, V]):
529    """
530    Mapping where all keys return the same value.
531
532    This rigamarole is meant to avoid copying keys, which was originally intended
533    as an optimization while qualifying columns for tables with lots of columns.
534    """
535
536    def __init__(self, keys: t.Collection[K], value: V):
537        self._keys = keys if isinstance(keys, Set) else set(keys)
538        self._value = value
539
540    def __getitem__(self, key: K) -> V:
541        if key in self._keys:
542            return self._value
543        raise KeyError(key)
544
545    def __len__(self) -> int:
546        return len(self._keys)
547
548    def __iter__(self) -> t.Iterator[K]:
549        return iter(self._keys)
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")
class AutoName(enum.Enum):

This is used for creating Enum classes where auto() is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
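
A quick illustration (the `Color` enum below is hypothetical):

>>> from enum import auto
>>> class Color(AutoName):
...     RED = auto()
...     GREEN = auto()
>>> Color.RED.value
'RED'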

class classproperty(builtins.property):

Similar to a normal property but works for class methods
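
For example (the `Config` class here is purely illustrative):

>>> class Config:
...     @classproperty
...     def dialect(cls):
...         return "duckdb"
>>> Config.dialect
'duckdb'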

def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:

Returns the value in seq at position index, or None if index is out of bounds.
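
For example:

>>> seq_get([1, 2, 3], 1)
2
>>> seq_get([1, 2, 3], 5) is None
True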

def ensure_list(value):

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: The value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
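
For example:

>>> ensure_list((1, 2))
[1, 2]
>>> ensure_list(None)
[]
>>> ensure_list("ab")
['ab']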

def ensure_collection(value):

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: The value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.
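
For example, a set is returned unchanged while a string is wrapped:

>>> ensure_collection({1, 2}) == {1, 2}
True
>>> ensure_collection("ab")
['ab']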

def csv(*args: str, sep: str = ', ') -> str:

Formats any number of string arguments as CSV.

Arguments:
  • args: The string arguments to format.
  • sep: The argument separator.
Returns:

The arguments formatted as a CSV string.
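
Note that empty arguments are dropped:

>>> csv("a", "b", "", "c")
'a, b, c'
>>> csv("a", "b", sep=" AND ")
'a AND b'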

def subclasses( module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: The name of the module to search for subclasses in.
  • classes: Class(es) we want to find the subclasses of.
  • exclude: Class(es) we want to exclude from the returned list.
Returns:

The target subclasses.
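
A rough sketch of typical use, assuming `sqlglot.expressions` is importable and using its `Func` base class:

from sqlglot import expressions as exp
from sqlglot.helper import subclasses

# All Func subclasses defined in sqlglot.expressions, excluding the base class itself.
func_classes = subclasses(exp.__name__, exp.Func, exclude=(exp.Func,))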

def apply_index_offset( this: sqlglot.expressions.Expression, expressions: List[~E], offset: int) -> List[~E]:

Applies an offset to a given integer literal expression.

Arguments:
  • this: The target of the index.
  • expressions: The expression the offset will be applied to, wrapped in a list.
  • offset: The offset that will be applied.
Returns:

The original expression with the offset applied to it, wrapped in a list. If the provided expressions argument contains more than one expression, it's returned unaffected.
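
A hedged sketch of the intended use, shifting a 1-based array index to 0-based (the column and literal below are constructed ad hoc for illustration):

from sqlglot import exp
from sqlglot.helper import apply_index_offset

array_col = exp.column("arr")            # type is unknown, so the offset is allowed
index = [exp.Literal.number(1)]          # the integer literal the offset applies to
shifted = apply_index_offset(array_col, index, -1)  # expected to yield the literal 0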

def camel_to_snake_case(name: str) -> str:

Converts name from camelCase to snake_case and returns the result.
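
Note that the result is uppercased:

>>> camel_to_snake_case("fooBarBaz")
'FOO_BAR_BAZ'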

def while_changing( expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:

Applies a transformation to a given expression until a fix point is reached.

Arguments:
  • expression: The expression to be transformed.
  • func: The transformation to be applied.
Returns:

The transformed expression.
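
A small sketch using sqlglot's `simplify` rule as the transformation (any `Expression -> Expression` callable works):

from sqlglot import parse_one
from sqlglot.helper import while_changing
from sqlglot.optimizer.simplify import simplify

expression = parse_one("1 + 1 + 1")
expression = while_changing(expression, simplify)  # reapplies simplify until the tree stops changing
print(expression.sql())  # expected: 3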

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: The graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.
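
For example (note that the input mapping is modified in place):

>>> tsort({"b": {"a"}, "c": {"a", "b"}, "a": set()})
['a', 'b', 'c']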

def open_file(file_name: str) -> TextIO:

Open a file that may be compressed as gzip and return it in universal newline mode.

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: A ReadCSV function call.
Yields:

A python csv reader.

def find_new_name(taken: Collection[str], base: str) -> str:

Searches for a new name.

Arguments:
  • taken: A collection of taken names.
  • base: Base name to alter.
Returns:

The new, available name.
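
For example:

>>> find_new_name({"a", "b"}, "c")
'c'
>>> find_new_name({"a", "a_2"}, "a")
'a_3'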

def is_int(text: str) -> bool:
def is_float(text: str) -> bool:
def is_type(text: str, target_type: Type) -> bool:
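
These three helpers simply test whether `text` can be converted to the target type:

>>> is_int("42"), is_int("4.2")
(True, False)
>>> is_float("4.2"), is_float("abc")
(True, False)
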
def name_sequence(prefix: str) -> Callable[[], str]:

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").
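
For example:

>>> next_name = name_sequence("t")
>>> next_name(), next_name(), next_name()
('t0', 't1', 't2')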

def object_to_dict(obj: Any, **kwargs) -> Dict:

Returns a dictionary created from an object's attributes.
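
Keyword arguments are merged into the result; the `Point` class below is purely illustrative:

>>> class Point:
...     def __init__(self):
...         self.x, self.y = 1, 2
>>> object_to_dict(Point(), z=3)
{'x': 1, 'y': 2, 'z': 3}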

def split_num_words( value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:

Perform a split on a value and return N words as a result with None used for words that don't exist.

Arguments:
  • value: The value to be split.
  • sep: The value to use to split on.
  • min_num_words: The minimum number of words that are going to be in the result.
  • fill_from_start: Indicates whether None values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:

Checks if the value is an iterable, excluding the types str and bytes.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: The value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: The value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:

Get the nesting depth of a dictionary.

Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: Iterable[~T]) -> ~T:

Returns the first element from an iterable (useful for sets).
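
For example (an empty iterable raises `StopIteration`):

>>> first({42})
42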

def to_bool(value: Union[str, bool, NoneType]) -> Union[str, bool, NoneType]:
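
to_bool coerces the strings "true"/"1" and "false"/"0" (case-insensitively) to booleans, returns booleans and None as-is, and returns any other value unchanged:

>>> to_bool("TRUE"), to_bool("0"), to_bool("maybe")
(True, False, 'maybe')
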
def merge_ranges(ranges: List[Tuple[~A, ~A]]) -> List[Tuple[~A, ~A]]:

Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.

Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]
def is_iso_date(text: str) -> bool:
def is_iso_datetime(text: str) -> bool:
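
Both helpers simply try the corresponding `fromisoformat` constructor and report whether it succeeds:

>>> is_iso_date("2023-01-15")
True
>>> is_iso_date("15/01/2023")
False
>>> is_iso_datetime("2023-01-15 10:30:00")
True
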
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
def is_date_unit(expression: Optional[sqlglot.expressions.Expression]) -> bool:
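
For example, using sqlglot's `exp.var` helper to build the unit expression:

>>> from sqlglot import exp
>>> is_date_unit(exp.var("MONTH"))
True
>>> is_date_unit(exp.var("HOUR"))
False
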
class SingleValuedMapping(typing.Mapping[~K, ~V]):

Mapping where all keys return the same value.

This rigamarole is meant to avoid copying keys, which was originally intended as an optimization while qualifying columns for tables with lots of columns.

SingleValuedMapping(keys: Collection[~K], value: ~V)
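
For example, mapping every column name of a wide table to the same source table:

>>> columns = SingleValuedMapping(["a", "b", "c"], "my_table")
>>> columns["b"]
'my_table'
>>> len(columns)
3
>>> "d" in columns
False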