
sqlglot.helper

  1from __future__ import annotations
  2
  3import datetime
  4import inspect
  5import logging
  6import re
  7import sys
  8import typing as t
  9from collections.abc import Collection, Set
 10from contextlib import contextmanager
 11from copy import copy
 12from enum import Enum
 13from itertools import count
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot import exp
 17    from sqlglot._typing import A, E, T
 18    from sqlglot.expressions import Expression
 19
 20
 21CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
 22PYTHON_VERSION = sys.version_info[:2]
 23logger = logging.getLogger("sqlglot")
 24
 25
 26class AutoName(Enum):
 27    """
 28    This is used for creating Enum classes where `auto()` is the string form
 29    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
 30
 31    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
 32    """
 33
 34    def _generate_next_value_(name, _start, _count, _last_values):
 35        return name
 36
 37
 38class classproperty(property):
 39    """
 40    Similar to a normal property but works for class methods
 41    """
 42
 43    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
 44        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
 45
 46
 47def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 48    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 49    try:
 50        return seq[index]
 51    except IndexError:
 52        return None
 53
 54
 55@t.overload
 56def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
 57
 58
 59@t.overload
 60def ensure_list(value: T) -> t.List[T]: ...
 61
 62
 63def ensure_list(value):
 64    """
 65    Ensures that a value is a list, otherwise casts or wraps it into one.
 66
 67    Args:
 68        value: The value of interest.
 69
 70    Returns:
 71        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 72    """
 73    if value is None:
 74        return []
 75    if isinstance(value, (list, tuple)):
 76        return list(value)
 77
 78    return [value]
 79
 80
 81@t.overload
 82def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
 83
 84
 85@t.overload
 86def ensure_collection(value: T) -> t.Collection[T]: ...
 87
 88
 89def ensure_collection(value):
 90    """
 91    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 92
 93    Args:
 94        value: The value of interest.
 95
 96    Returns:
 97        The value if it's a collection, or else the value wrapped in a list.
 98    """
 99    if value is None:
100        return []
101    return (
102        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
103    )
104
105
106def csv(*args: str, sep: str = ", ") -> str:
107    """
108    Formats any number of string arguments as CSV.
109
110    Args:
111        args: The string arguments to format.
112        sep: The argument separator.
113
114    Returns:
115        The arguments formatted as a CSV string.
116    """
117    return sep.join(arg for arg in args if arg)
118
119
120def subclasses(
121    module_name: str,
122    classes: t.Type | t.Tuple[t.Type, ...],
123    exclude: t.Type | t.Tuple[t.Type, ...] = (),
124) -> t.List[t.Type]:
125    """
126    Returns all subclasses for a collection of classes, possibly excluding some of them.
127
128    Args:
129        module_name: The name of the module to search for subclasses in.
130        classes: Class(es) we want to find the subclasses of.
131        exclude: Class(es) we want to exclude from the returned list.
132
133    Returns:
134        The target subclasses.
135    """
136    return [
137        obj
138        for _, obj in inspect.getmembers(
139            sys.modules[module_name],
140            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
141        )
142    ]
143
144
145def apply_index_offset(
146    this: exp.Expression,
147    expressions: t.List[E],
148    offset: int,
149) -> t.List[E]:
150    """
151    Applies an offset to a given integer literal expression.
152
153    Args:
154        this: The target of the index.
155        expressions: The expression the offset will be applied to, wrapped in a list.
156        offset: The offset that will be applied.
157
158    Returns:
159        The original expression with the offset applied to it, wrapped in a list. If the provided
160        `expressions` argument contains more than one expression, it's returned unaffected.
161    """
162    if not offset or len(expressions) != 1:
163        return expressions
164
165    expression = expressions[0]
166
167    from sqlglot import exp
168    from sqlglot.optimizer.annotate_types import annotate_types
169    from sqlglot.optimizer.simplify import simplify
170
171    if not this.type:
172        annotate_types(this)
173
174    if t.cast(exp.DataType, this.type).this not in (
175        exp.DataType.Type.UNKNOWN,
176        exp.DataType.Type.ARRAY,
177    ):
178        return expressions
179
180    if not expression.type:
181        annotate_types(expression)
182    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
183        logger.warning("Applying array index offset (%s)", offset)
184        expression = simplify(expression + offset)
185        return [expression]
186
187    return expressions
188
189
190def camel_to_snake_case(name: str) -> str:
191    """Converts `name` from camelCase to snake_case and returns the result."""
192    return CAMEL_CASE_PATTERN.sub("_", name).upper()
193
194
195def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
196    """
197    Applies a transformation to a given expression until a fix point is reached.
198
199    Args:
200        expression: The expression to be transformed.
201        func: The transformation to be applied.
202
203    Returns:
204        The transformed expression.
205    """
206    while True:
207        for n in reversed(tuple(expression.walk())):
208            n._hash = hash(n)
209
210        start = hash(expression)
211        expression = func(expression)
212
213        for n in expression.walk():
214            n._hash = None
215        if start == hash(expression):
216            break
217
218    return expression
219
220
221def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
222    """
223    Sorts a given directed acyclic graph in topological order.
224
225    Args:
226        dag: The graph to be sorted.
227
228    Returns:
229        A list that contains all of the graph's nodes in topological order.
230    """
231    result = []
232
233    for node, deps in tuple(dag.items()):
234        for dep in deps:
235            if dep not in dag:
236                dag[dep] = set()
237
238    while dag:
239        current = {node for node, deps in dag.items() if not deps}
240
241        if not current:
242            raise ValueError("Cycle error")
243
244        for node in current:
245            dag.pop(node)
246
247        for deps in dag.values():
248            deps -= current
249
250        result.extend(sorted(current))  # type: ignore
251
252    return result
253
254
255def open_file(file_name: str) -> t.TextIO:
256    """Open a file that may be compressed as gzip and return it in universal newline mode."""
257    with open(file_name, "rb") as f:
258        gzipped = f.read(2) == b"\x1f\x8b"
259
260    if gzipped:
261        import gzip
262
263        return gzip.open(file_name, "rt", newline="")
264
265    return open(file_name, encoding="utf-8", newline="")
266
267
268@contextmanager
269def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
270    """
271    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
272
273    Args:
274        read_csv: A `ReadCSV` function call.
275
276    Yields:
277        A python csv reader.
278    """
279    args = read_csv.expressions
280    file = open_file(read_csv.name)
281
282    delimiter = ","
283    args = iter(arg.name for arg in args)  # type: ignore
284    for k, v in zip(args, args):
285        if k == "delimiter":
286            delimiter = v
287
288    try:
289        import csv as csv_
290
291        yield csv_.reader(file, delimiter=delimiter)
292    finally:
293        file.close()
294
295
296def find_new_name(taken: t.Collection[str], base: str) -> str:
297    """
298    Searches for a new name.
299
300    Args:
301        taken: A collection of taken names.
302        base: Base name to alter.
303
304    Returns:
305        The new, available name.
306    """
307    if base not in taken:
308        return base
309
310    i = 2
311    new = f"{base}_{i}"
312    while new in taken:
313        i += 1
314        new = f"{base}_{i}"
315
316    return new
317
318
319def is_int(text: str) -> bool:
320    return is_type(text, int)
321
322
323def is_float(text: str) -> bool:
324    return is_type(text, float)
325
326
327def is_type(text: str, target_type: t.Type) -> bool:
328    try:
329        target_type(text)
330        return True
331    except ValueError:
332        return False
333
334
335def name_sequence(prefix: str) -> t.Callable[[], str]:
336    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
337    sequence = count()
338    return lambda: f"{prefix}{next(sequence)}"
339
340
341def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
342    """Returns a dictionary created from an object's attributes."""
343    return {
344        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
345        **kwargs,
346    }
347
348
349def split_num_words(
350    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
351) -> t.List[t.Optional[str]]:
352    """
353    Perform a split on a value and return N words as a result with `None` used for words that don't exist.
354
355    Args:
356        value: The value to be split.
357        sep: The value to use to split on.
358        min_num_words: The minimum number of words that are going to be in the result.
359        fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list.
360
361    Examples:
362        >>> split_num_words("db.table", ".", 3)
363        [None, 'db', 'table']
364        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
365        ['db', 'table', None]
366        >>> split_num_words("db.table", ".", 1)
367        ['db', 'table']
368
369    Returns:
370        The list of words returned by `split`, possibly augmented by a number of `None` values.
371    """
372    words = value.split(sep)
373    if fill_from_start:
374        return [None] * (min_num_words - len(words)) + words
375    return words + [None] * (min_num_words - len(words))
376
377
378def is_iterable(value: t.Any) -> bool:
379    """
380    Checks if the value is an iterable, excluding the types `str` and `bytes`.
381
382    Examples:
383        >>> is_iterable([1,2])
384        True
385        >>> is_iterable("test")
386        False
387
388    Args:
389        value: The value to check if it is an iterable.
390
391    Returns:
392        A `bool` value indicating if it is an iterable.
393    """
394    from sqlglot import Expression
395
396    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
397
398
399def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
400    """
401    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
402    type `str` and `bytes` are not regarded as iterables.
403
404    Examples:
405        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
406        [1, 2, 3, 4, 5, 'bla']
407        >>> list(flatten([1, 2, 3]))
408        [1, 2, 3]
409
410    Args:
411        values: The value to be flattened.
412
413    Yields:
414        Non-iterable elements in `values`.
415    """
416    for value in values:
417        if is_iterable(value):
418            yield from flatten(value)
419        else:
420            yield value
421
422
423def dict_depth(d: t.Dict) -> int:
424    """
425    Get the nesting depth of a dictionary.
426
427    Example:
428        >>> dict_depth(None)
429        0
430        >>> dict_depth({})
431        1
432        >>> dict_depth({"a": "b"})
433        1
434        >>> dict_depth({"a": {}})
435        2
436        >>> dict_depth({"a": {"b": {}}})
437        3
438    """
439    try:
440        return 1 + dict_depth(next(iter(d.values())))
441    except AttributeError:
442        # d doesn't have attribute "values"
443        return 0
444    except StopIteration:
445        # d.values() returns an empty sequence
446        return 1
447
448
449def first(it: t.Iterable[T]) -> T:
450    """Returns the first element from an iterable (useful for sets)."""
451    return next(i for i in it)
452
453
454def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
455    """
456    Merges a sequence of ranges, represented as tuples (low, high) whose values
457    belong to some totally-ordered set.
458
459    Example:
460        >>> merge_ranges([(1, 3), (2, 6)])
461        [(1, 6)]
462    """
463    if not ranges:
464        return []
465
466    ranges = sorted(ranges)
467
468    merged = [ranges[0]]
469
470    for start, end in ranges[1:]:
471        last_start, last_end = merged[-1]
472
473        if start <= last_end:
474            merged[-1] = (last_start, max(last_end, end))
475        else:
476            merged.append((start, end))
477
478    return merged
479
480
481def is_iso_date(text: str) -> bool:
482    try:
483        datetime.date.fromisoformat(text)
484        return True
485    except ValueError:
486        return False
487
488
489def is_iso_datetime(text: str) -> bool:
490    try:
491        datetime.datetime.fromisoformat(text)
492        return True
493    except ValueError:
494        return False
495
496
497# Interval units that operate on date components
498DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
499
500
501def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
502    return expression is not None and expression.name.lower() in DATE_UNITS
503
504
505K = t.TypeVar("K")
506V = t.TypeVar("V")
507
508
509class SingleValuedMapping(t.Mapping[K, V]):
510    """
511    Mapping where all keys return the same value.
512
513    This rigamarole is meant to avoid copying keys, which was originally intended
514    as an optimization while qualifying columns for tables with lots of columns.
515    """
516
517    def __init__(self, keys: t.Collection[K], value: V):
518        self._keys = keys if isinstance(keys, Set) else set(keys)
519        self._value = value
520
521    def __getitem__(self, key: K) -> V:
522        if key in self._keys:
523            return self._value
524        raise KeyError(key)
525
526    def __len__(self) -> int:
527        return len(self._keys)
528
529    def __iter__(self) -> t.Iterator[K]:
530        return iter(self._keys)

class AutoName(enum.Enum):

This is used for creating Enum classes where auto() is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values

class classproperty(builtins.property):

Similar to a normal property, but it is accessed on the class itself: the getter is wrapped in classmethod and invoked with the owning class.

def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:

Returns the value in seq at position index, or None if index is out of bounds.
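
For instance (the sample list is arbitrary):

>>> from sqlglot.helper import seq_get
>>> seq_get(["a", "b"], 1)
'b'
>>> seq_get(["a", "b"], 5) is None
True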

def ensure_list(value):

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: The value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
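
A quick doctest-style illustration covering the three branches:

>>> from sqlglot.helper import ensure_list
>>> ensure_list(None)
[]
>>> ensure_list((1, 2))
[1, 2]
>>> ensure_list(1)
[1]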

def ensure_collection(value):

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: The value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.

def csv(*args: str, sep: str = ', ') -> str:

Formats any number of string arguments as CSV.

Arguments:
  • args: The string arguments to format.
  • sep: The argument separator.
Returns:

The arguments formatted as a CSV string.
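
Note that empty arguments are skipped rather than producing dangling separators:

>>> from sqlglot.helper import csv
>>> csv("a", "", "b", sep=" ")
'a b'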

def subclasses( module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: The name of the module to search for subclasses in.
  • classes: Class(es) we want to find the subclasses of.
  • exclude: Class(es) we want to exclude from the returned list.
Returns:

The target subclasses.
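
As a hedged sketch of how this could be used against sqlglot's own expression module (the choice of exp.Binary and exp.Add here is illustrative):

>>> from sqlglot import exp
>>> from sqlglot.helper import subclasses
>>> binary_types = subclasses(exp.__name__, exp.Binary, exclude=(exp.Binary,))
>>> exp.Add in binary_types
True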

def apply_index_offset( this: sqlglot.expressions.Expression, expressions: List[~E], offset: int) -> List[~E]:

Applies an offset to a given integer literal expression.

Arguments:
  • this: The target of the index.
  • expressions: The expression the offset will be applied to, wrapped in a list.
  • offset: The offset that will be applied.
Returns:

The original expression with the offset applied to it, wrapped in a list. If the provided expressions argument contains more than one expression, it's returned unaffected.
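
A hedged sketch of the intended usage, mirroring the 1-based to 0-based index shift that SQL generators typically need; the bracket expression and the -1 offset below are illustrative assumptions:

    from sqlglot import exp, parse_one
    from sqlglot.helper import apply_index_offset

    # Locate the Bracket (array indexing) node inside a parsed query.
    bracket = parse_one("SELECT some_array[1]").find(exp.Bracket)

    # Shift the literal index by -1. When the index is an integer literal, a new
    # single-element list is returned; otherwise the original list comes back unchanged.
    new_expressions = apply_index_offset(bracket.this, bracket.expressions, -1)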

def camel_to_snake_case(name: str) -> str:

Converts name from camelCase to upper-case snake_case (e.g. "fooBar" becomes "FOO_BAR") and returns the result.
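
A quick doctest-style illustration:

>>> from sqlglot.helper import camel_to_snake_case
>>> camel_to_snake_case("fooBarBaz")
'FOO_BAR_BAZ'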

def while_changing( expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:

Applies a transformation to a given expression until a fixed point is reached.

Arguments:
  • expression: The expression to be transformed.
  • func: The transformation to be applied.
Returns:

The transformed expression.
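
A minimal, hedged sketch of driving a custom rewrite to a fixed point; the paren-stripping rule below is illustrative and not part of sqlglot:

    from sqlglot import exp, parse_one
    from sqlglot.helper import while_changing

    def strip_parens(expression: exp.Expression) -> exp.Expression:
        # Replace every Paren node with its inner expression.
        return expression.transform(
            lambda node: node.this if isinstance(node, exp.Paren) else node
        )

    expression = while_changing(parse_one("SELECT ((1))"), strip_parens)

Note that the helper temporarily caches node hashes around each pass, so comparing the before and after trees stays cheap even for large expressions.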

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: The graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.
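
For example, with a two-node graph where each node maps to the set of nodes it depends on:

>>> from sqlglot.helper import tsort
>>> tsort({"b": {"a"}, "a": set()})
['a', 'b']

Note that the input mapping is consumed (its nodes are popped) while sorting.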

def open_file(file_name: str) -> <class 'TextIO'>:

Open a file that may be compressed as gzip and return it in universal newline mode.

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: A ReadCSV function call.
Yields:

A python csv reader.
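
A hedged usage sketch, assuming the READ_CSV call is parsed out of a query; the file name and delimiter are placeholders:

    from sqlglot import exp, parse_one
    from sqlglot.helper import csv_reader

    read_csv = parse_one("SELECT * FROM READ_CSV('data.csv', 'delimiter', '|')").find(exp.ReadCSV)

    with csv_reader(read_csv) as reader:
        for row in reader:
            ...  # each row is a list of strings

The context manager closes the underlying file once the block exits.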

def find_new_name(taken: Collection[str], base: str) -> str:

Searches for a new name.

Arguments:
  • taken: A collection of taken names.
  • base: Base name to alter.
Returns:

The new, available name.
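
For example:

>>> from sqlglot.helper import find_new_name
>>> find_new_name({"a", "b"}, "c")
'c'
>>> find_new_name({"a", "a_2"}, "a")
'a_3'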

def is_int(text: str) -> bool:

Returns True if text can be parsed as an int, otherwise False.

def is_float(text: str) -> bool:

Returns True if text can be parsed as a float, otherwise False.

def is_type(text: str, target_type: Type) -> bool:

Returns True if target_type(text) succeeds, and False if it raises a ValueError.
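
A couple of doctest-style checks of the conversion-based approach:

>>> from sqlglot.helper import is_int, is_float
>>> is_int("42"), is_int("4.2")
(True, False)
>>> is_float("4.2"), is_float("abc")
(True, False)
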
def name_sequence(prefix: str) -> Callable[[], str]:

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").
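
For example:

>>> from sqlglot.helper import name_sequence
>>> next_name = name_sequence("a")
>>> next_name(), next_name(), next_name()
('a0', 'a1', 'a2')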

def object_to_dict(obj: Any, **kwargs) -> Dict:

Returns a dictionary created from an object's attributes.
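
A small doctest-style sketch with a throwaway class (Widget is purely illustrative); attribute values are shallow-copied and keyword arguments are merged into the result:

>>> from sqlglot.helper import object_to_dict
>>> class Widget:
...     def __init__(self):
...         self.tags = ["a"]
>>> d = object_to_dict(Widget(), extra=1)
>>> d["tags"], d["extra"]
(['a'], 1)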

def split_num_words( value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:

Perform a split on a value and return N words as a result with None used for words that don't exist.

Arguments:
  • value: The value to be split.
  • sep: The value to use to split on.
  • min_num_words: The minimum number of words that are going to be in the result.
  • fill_from_start: Indicates whether None values should be inserted at the start or the end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:

Checks if the value is an iterable, excluding str, bytes, and sqlglot Expression objects.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: The value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: The value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:

Get the nesting depth of a dictionary.

Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: Iterable[~T]) -> ~T:

Returns the first element from an iterable (useful for sets).

def merge_ranges(ranges: List[Tuple[~A, ~A]]) -> List[Tuple[~A, ~A]]:

Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.

Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]

def is_iso_date(text: str) -> bool:

Returns True if text can be parsed by datetime.date.fromisoformat, otherwise False.

def is_iso_datetime(text: str) -> bool:

Returns True if text can be parsed by datetime.datetime.fromisoformat, otherwise False.

DATE_UNITS = {'year', 'week', 'day', 'year_month', 'quarter', 'month'}

Interval units that operate on date components.

def is_date_unit(expression: Optional[sqlglot.expressions.Expression]) -> bool:

Returns True if the expression is not None and its lower-cased name is one of DATE_UNITS, otherwise False.

class SingleValuedMapping(typing.Mapping[~K, ~V]):

Mapping where all keys return the same value.

This rigamarole is meant to avoid copying keys, which was originally intended as an optimization while qualifying columns for tables with lots of columns.

SingleValuedMapping(keys: Collection[~K], value: ~V)
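
For example:

>>> from sqlglot.helper import SingleValuedMapping
>>> columns = SingleValuedMapping(["a", "b", "c"], 0)
>>> columns["b"]
0
>>> len(columns)
3
>>> "d" in columns
False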