sqlglot.helper
from __future__ import annotations

import datetime
import inspect
import logging
import re
import sys
import typing as t
from collections.abc import Collection, Set
from contextlib import contextmanager
from copy import copy
from difflib import get_close_matches
from enum import Enum
from itertools import count

if t.TYPE_CHECKING:
    from sqlglot import exp
    from sqlglot._typing import A, E, T
    from sqlglot.dialects.dialect import DialectType
    from sqlglot.expressions import Expression


# Matches every position that precedes an uppercase letter, except at the very
# start of the string — used to split camelCase identifiers into words.
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")


class AutoName(Enum):
    """
    This is used for creating Enum classes where `auto()` is the string form
    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    def _generate_next_value_(name, _start, _count, _last_values):
        # The enum machinery calls this with the member's name; returning it
        # makes `auto()` produce the name itself as the value.
        return name


class classproperty(property):
    """Similar to a normal property but works for class methods."""

    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
        # Bind `fget` as a classmethod on the owning class and invoke it right
        # away, so attribute access on the class yields the computed value.
        return classmethod(self.fget).__get__(None, owner)()  # type: ignore


def suggest_closest_match_and_fail(
    kind: str,
    word: str,
    possibilities: t.Iterable[str],
) -> None:
    """
    Raises a `ValueError` for an unknown `word`, suggesting the closest match
    from `possibilities` when one is similar enough.

    Args:
        kind: A label describing what `word` is (e.g. "dialect"), used in the message.
        word: The unrecognized input.
        possibilities: The set of valid alternatives to search for a suggestion.

    Raises:
        ValueError: Always; the message includes a "Did you mean ...?" hint if found.
    """
    matches = get_close_matches(word, possibilities, n=1)
    best = seq_get(matches, 0) or ""
    suffix = f" Did you mean {best}?" if best else ""
    raise ValueError(f"Unknown {kind} '{word}'.{suffix}")


def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
    # EAFP keeps Python's negative-index semantics intact, which an explicit
    # bounds check would have to reimplement.
    try:
        return seq[index]
    except IndexError:
        return None


@t.overload
def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
@t.overload
def ensure_list(value: None) -> t.List: ...


@t.overload
def ensure_list(value: T) -> t.List[T]: ...


def ensure_list(value):
    """
    Ensures that a value is a list, otherwise casts or wraps it into one.

    Args:
        value: The value of interest.

    Returns:
        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)

    return [value]


@t.overload
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...


@t.overload
def ensure_collection(value: T) -> t.Collection[T]: ...


def ensure_collection(value):
    """
    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.

    Args:
        value: The value of interest.

    Returns:
        The value if it's a collection, or else the value wrapped in a list.
    """
    if value is None:
        return []
    return (
        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
    )


def csv(*args: str, sep: str = ", ") -> str:
    """
    Formats any number of string arguments as CSV.

    Args:
        args: The string arguments to format. Empty/falsy arguments are skipped.
        sep: The argument separator.

    Returns:
        The arguments formatted as a CSV string.
    """
    return sep.join(arg for arg in args if arg)


def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: The name of the module to search for subclasses in.
        classes: Class(es) we want to find the subclasses of.
        exclude: Class(es) we want to exclude from the returned list.

    Returns:
        The target subclasses.
    """
    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
        )
    ]


def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
    dialect: DialectType = None,
) -> t.List[E]:
    """
    Applies an offset to a given integer literal expression.

    Args:
        this: The target of the index.
        expressions: The expression the offset will be applied to, wrapped in a list.
        offset: The offset that will be applied.
        dialect: the dialect of interest.

    Returns:
        The original expression with the offset applied to it, wrapped in a list. If the provided
        `expressions` argument contains more than one expression, it's returned unaffected.
    """
    if not offset or len(expressions) != 1:
        return expressions

    expression = expressions[0]

    # Imported lazily to avoid a circular import at module load time.
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    if not this.type:
        annotate_types(this, dialect=dialect)

    # Only arrays (or values of unknown type) are indexable, so anything else
    # is returned untouched.
    if t.cast(exp.DataType, this.type).this not in (
        exp.DataType.Type.UNKNOWN,
        exp.DataType.Type.ARRAY,
    ):
        return expressions

    if not expression.type:
        annotate_types(expression, dialect=dialect)

    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
        logger.info("Applying array index offset (%s)", offset)
        expression = simplify(expression + offset)
        return [expression]

    return expressions


def camel_to_snake_case(name: str) -> str:
    """Converts `name` from camelCase to upper-cased SNAKE_CASE and returns the result."""
    # Note: the result is upper-cased (e.g. "fooBar" -> "FOO_BAR"), matching
    # how the rest of the codebase consumes these identifiers.
    return CAMEL_CASE_PATTERN.sub("_", name).upper()


def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Applies a transformation to a given expression until a fix point is reached.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied.

    Returns:
        The transformed expression.
    """

    while True:
        # The fix point is detected by comparing hashes before and after the
        # transformation, avoiding a potentially expensive deep equality check.
        start_hash = hash(expression)
        expression = func(expression)
        end_hash = hash(expression)

        if start_hash == end_hash:
            break

    return expression


def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Args:
        dag: The graph to be sorted. It is not modified by this call.

    Returns:
        A list that contains all of the graph's nodes in topological order.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    result = []

    # Work on a copy of the graph (including fresh dependency sets) so the
    # caller's dict and its sets are not mutated or emptied by the sort.
    graph = {node: set(deps) for node, deps in dag.items()}

    # Ensure every node mentioned as a dependency is present as a key.
    for deps in tuple(graph.values()):
        for dep in deps:
            if dep not in graph:
                graph[dep] = set()

    while graph:
        # Kahn's algorithm: repeatedly peel off the nodes with no remaining deps.
        current = {node for node, deps in graph.items() if not deps}

        if not current:
            raise ValueError("Cycle error")

        for node in current:
            graph.pop(node)

        for deps in graph.values():
            deps -= current

        # Sorted for deterministic output among nodes at the same depth.
        result.extend(sorted(current))  # type: ignore

    return result


def open_file(file_name: str) -> t.TextIO:
    """Open a file that may be compressed as gzip and return it in universal newline mode."""
    # Peek at the first two bytes to detect the gzip magic number.
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Explicit utf-8 keeps the gzip branch consistent with the plain-file
        # branch below (the default would be locale-dependent).
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")


@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader.
    """
    args = read_csv.expressions
    file = open_file(read_csv.name)

    delimiter = ","
    # zip an iterator against itself to walk the flat arg list as (key, value)
    # pairs, e.g. ['delimiter', '|'] -> [('delimiter', '|')].
    args = iter(arg.name for arg in args)  # type: ignore
    for k, v in zip(args, args):
        if k == "delimiter":
            delimiter = v

    try:
        import csv as csv_

        yield csv_.reader(file, delimiter=delimiter)
    finally:
        file.close()


def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Searches for a new name.

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        The new, available name.
    """
    if base not in taken:
        return base

    i = 2
    new = f"{base}_{i}"
    while new in taken:
        i += 1
        new = f"{base}_{i}"

    return new


def is_int(text: str) -> bool:
    """Returns `True` if `text` parses as an `int`."""
    return is_type(text, int)


def is_float(text: str) -> bool:
    """Returns `True` if `text` parses as a `float`."""
    return is_type(text, float)


def is_type(text: str, target_type: t.Type) -> bool:
    """Returns `True` if `text` can be converted into `target_type` without a `ValueError`."""
    try:
        target_type(text)
        return True
    except ValueError:
        return False


def name_sequence(prefix: str) -> t.Callable[[], str]:
    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
    sequence = count()
    return lambda: f"{prefix}{next(sequence)}"


def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """Returns a dictionary created from an object's attributes."""
    # Attribute values are (shallow-)copied so mutating the dict's values
    # doesn't affect the original object.
    return {
        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
        **kwargs,
    }


def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Perform a split on a value and return N words as a result with `None` used for words that don't exist.

    Args:
        value: The value to be split.
        sep: The value to use to split on.
        min_num_words: The minimum number of words that are going to be in the result.
        fill_from_start: Indicates whether `None` values should be inserted at the start or end of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly augmented by a number of `None` values.
    """
    words = value.split(sep)
    if fill_from_start:
        return [None] * (min_num_words - len(words)) + words
    return words + [None] * (min_num_words - len(words))


def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is an iterable, excluding the types `str` and `bytes`.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    from sqlglot import Expression

    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))


def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
    type `str` and `bytes` are not regarded as iterables.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`.
    """
    for value in values:
        if is_iterable(value):
            yield from flatten(value)
        else:
            yield value


def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        return 1 + dict_depth(next(iter(d.values())))
    except AttributeError:
        # d doesn't have attribute "values"
        return 0
    except StopIteration:
        # d.values() returns an empty sequence
        return 1


def first(it: t.Iterable[T]) -> T:
    """Returns the first element from an iterable (useful for sets)."""
    return next(iter(it))


def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
    """
    Coerces a string to a boolean when it matches a known truthy/falsy value,
    otherwise returns the value unchanged (booleans and `None` pass through).
    """
    if isinstance(value, bool) or value is None:
        return value

    # Coerce the value to boolean if it matches to the truthy/falsy values below
    value_lower = value.lower()
    if value_lower in ("true", "1"):
        return True
    if value_lower in ("false", "0"):
        return False

    return value


def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Merges a sequence of ranges, represented as tuples (low, high) whose values
    belong to some totally-ordered set.

    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    # Sorting by start guarantees that overlapping ranges become adjacent.
    ranges = sorted(ranges)

    merged = [ranges[0]]

    for start, end in ranges[1:]:
        last_start, last_end = merged[-1]

        if start <= last_end:
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))

    return merged


def is_iso_date(text: str) -> bool:
    """Returns `True` if `text` is a valid ISO-8601 date (e.g. "2023-01-02")."""
    try:
        datetime.date.fromisoformat(text)
        return True
    except ValueError:
        return False


def is_iso_datetime(text: str) -> bool:
    """Returns `True` if `text` is a valid ISO-8601 datetime (e.g. "2023-01-02T03:04:05")."""
    try:
        datetime.datetime.fromisoformat(text)
        return True
    except ValueError:
        return False


# Interval units that operate on date components
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}


def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
    """Returns `True` if `expression` names an interval unit that operates on date components."""
    return expression is not None and expression.name.lower() in DATE_UNITS


K = t.TypeVar("K")
V = t.TypeVar("V")


class SingleValuedMapping(t.Mapping[K, V]):
    """
    Mapping where all keys return the same value.

    This rigamarole is meant to avoid copying keys, which was originally intended
    as an optimization while qualifying columns for tables with lots of columns.
    """

    def __init__(self, keys: t.Collection[K], value: V):
        # Reuse the caller's set when possible to avoid copying the keys.
        self._keys = keys if isinstance(keys, Set) else set(keys)
        self._value = value

    def __getitem__(self, key: K) -> V:
        if key in self._keys:
            return self._value
        raise KeyError(key)

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> t.Iterator[K]:
        return iter(self._keys)
29class AutoName(Enum): 30 """ 31 This is used for creating Enum classes where `auto()` is the string form 32 of the corresponding enum's identifier (e.g. FOO.value results in "FOO"). 33 34 Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values 35 """ 36 37 def _generate_next_value_(name, _start, _count, _last_values): 38 return name
This is used for creating Enum classes where auto()
is the string form
of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
41class classproperty(property): 42 """ 43 Similar to a normal property but works for class methods 44 """ 45 46 def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any: 47 return classmethod(self.fget).__get__(None, owner)() # type: ignore
Similar to a normal property but works for class methods
50def suggest_closest_match_and_fail( 51 kind: str, 52 word: str, 53 possibilities: t.Iterable[str], 54) -> None: 55 close_matches = get_close_matches(word, possibilities, n=1) 56 57 similar = seq_get(close_matches, 0) or "" 58 if similar: 59 similar = f" Did you mean {similar}?" 60 61 raise ValueError(f"Unknown {kind} '{word}'.{similar}")
64def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]: 65 """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds.""" 66 try: 67 return seq[index] 68 except IndexError: 69 return None
Returns the value in seq
at position index
, or None
if index
is out of bounds.
84def ensure_list(value): 85 """ 86 Ensures that a value is a list, otherwise casts or wraps it into one. 87 88 Args: 89 value: The value of interest. 90 91 Returns: 92 The value cast as a list if it's a list or a tuple, or else the value wrapped in a list. 93 """ 94 if value is None: 95 return [] 96 if isinstance(value, (list, tuple)): 97 return list(value) 98 99 return [value]
Ensures that a value is a list, otherwise casts or wraps it into one.
Arguments:
- value: The value of interest.
Returns:
The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
110def ensure_collection(value): 111 """ 112 Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list. 113 114 Args: 115 value: The value of interest. 116 117 Returns: 118 The value if it's a collection, or else the value wrapped in a list. 119 """ 120 if value is None: 121 return [] 122 return ( 123 value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value] 124 )
Ensures that a value is a collection (excluding str
and bytes
), otherwise wraps it into a list.
Arguments:
- value: The value of interest.
Returns:
The value if it's a collection, or else the value wrapped in a list.
127def csv(*args: str, sep: str = ", ") -> str: 128 """ 129 Formats any number of string arguments as CSV. 130 131 Args: 132 args: The string arguments to format. 133 sep: The argument separator. 134 135 Returns: 136 The arguments formatted as a CSV string. 137 """ 138 return sep.join(arg for arg in args if arg)
Formats any number of string arguments as CSV.
Arguments:
- args: The string arguments to format.
- sep: The argument separator.
Returns:
The arguments formatted as a CSV string.
141def subclasses( 142 module_name: str, 143 classes: t.Type | t.Tuple[t.Type, ...], 144 exclude: t.Type | t.Tuple[t.Type, ...] = (), 145) -> t.List[t.Type]: 146 """ 147 Returns all subclasses for a collection of classes, possibly excluding some of them. 148 149 Args: 150 module_name: The name of the module to search for subclasses in. 151 classes: Class(es) we want to find the subclasses of. 152 exclude: Class(es) we want to exclude from the returned list. 153 154 Returns: 155 The target subclasses. 156 """ 157 return [ 158 obj 159 for _, obj in inspect.getmembers( 160 sys.modules[module_name], 161 lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude, 162 ) 163 ]
Returns all subclasses for a collection of classes, possibly excluding some of them.
Arguments:
- module_name: The name of the module to search for subclasses in.
- classes: Class(es) we want to find the subclasses of.
- exclude: Class(es) we want to exclude from the returned list.
Returns:
The target subclasses.
166def apply_index_offset( 167 this: exp.Expression, 168 expressions: t.List[E], 169 offset: int, 170 dialect: DialectType = None, 171) -> t.List[E]: 172 """ 173 Applies an offset to a given integer literal expression. 174 175 Args: 176 this: The target of the index. 177 expressions: The expression the offset will be applied to, wrapped in a list. 178 offset: The offset that will be applied. 179 dialect: the dialect of interest. 180 181 Returns: 182 The original expression with the offset applied to it, wrapped in a list. If the provided 183 `expressions` argument contains more than one expression, it's returned unaffected. 184 """ 185 if not offset or len(expressions) != 1: 186 return expressions 187 188 expression = expressions[0] 189 190 from sqlglot import exp 191 from sqlglot.optimizer.annotate_types import annotate_types 192 from sqlglot.optimizer.simplify import simplify 193 194 if not this.type: 195 annotate_types(this, dialect=dialect) 196 197 if t.cast(exp.DataType, this.type).this not in ( 198 exp.DataType.Type.UNKNOWN, 199 exp.DataType.Type.ARRAY, 200 ): 201 return expressions 202 203 if not expression.type: 204 annotate_types(expression, dialect=dialect) 205 206 if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES: 207 logger.info("Applying array index offset (%s)", offset) 208 expression = simplify(expression + offset) 209 return [expression] 210 211 return expressions
Applies an offset to a given integer literal expression.
Arguments:
- this: The target of the index.
- expressions: The expression the offset will be applied to, wrapped in a list.
- offset: The offset that will be applied.
- dialect: the dialect of interest.
Returns:
The original expression with the offset applied to it, wrapped in a list. If the provided
expressions
argument contains more than one expression, it's returned unaffected.
214def camel_to_snake_case(name: str) -> str: 215 """Converts `name` from camelCase to snake_case and returns the result.""" 216 return CAMEL_CASE_PATTERN.sub("_", name).upper()
Converts name
from camelCase to snake_case and returns the result.
219def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E: 220 """ 221 Applies a transformation to a given expression until a fix point is reached. 222 223 Args: 224 expression: The expression to be transformed. 225 func: The transformation to be applied. 226 227 Returns: 228 The transformed expression. 229 """ 230 231 while True: 232 start_hash = hash(expression) 233 expression = func(expression) 234 end_hash = hash(expression) 235 236 if start_hash == end_hash: 237 break 238 239 return expression
Applies a transformation to a given expression until a fix point is reached.
Arguments:
- expression: The expression to be transformed.
- func: The transformation to be applied.
Returns:
The transformed expression.
242def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]: 243 """ 244 Sorts a given directed acyclic graph in topological order. 245 246 Args: 247 dag: The graph to be sorted. 248 249 Returns: 250 A list that contains all of the graph's nodes in topological order. 251 """ 252 result = [] 253 254 for node, deps in tuple(dag.items()): 255 for dep in deps: 256 if dep not in dag: 257 dag[dep] = set() 258 259 while dag: 260 current = {node for node, deps in dag.items() if not deps} 261 262 if not current: 263 raise ValueError("Cycle error") 264 265 for node in current: 266 dag.pop(node) 267 268 for deps in dag.values(): 269 deps -= current 270 271 result.extend(sorted(current)) # type: ignore 272 273 return result
Sorts a given directed acyclic graph in topological order.
Arguments:
- dag: The graph to be sorted.
Returns:
A list that contains all of the graph's nodes in topological order.
276def open_file(file_name: str) -> t.TextIO: 277 """Open a file that may be compressed as gzip and return it in universal newline mode.""" 278 with open(file_name, "rb") as f: 279 gzipped = f.read(2) == b"\x1f\x8b" 280 281 if gzipped: 282 import gzip 283 284 return gzip.open(file_name, "rt", newline="") 285 286 return open(file_name, encoding="utf-8", newline="")
Open a file that may be compressed as gzip and return it in universal newline mode.
289@contextmanager 290def csv_reader(read_csv: exp.ReadCSV) -> t.Any: 291 """ 292 Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`. 293 294 Args: 295 read_csv: A `ReadCSV` function call. 296 297 Yields: 298 A python csv reader. 299 """ 300 args = read_csv.expressions 301 file = open_file(read_csv.name) 302 303 delimiter = "," 304 args = iter(arg.name for arg in args) # type: ignore 305 for k, v in zip(args, args): 306 if k == "delimiter": 307 delimiter = v 308 309 try: 310 import csv as csv_ 311 312 yield csv_.reader(file, delimiter=delimiter) 313 finally: 314 file.close()
Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...])
.
Arguments:
- read_csv: A
ReadCSV
function call.
Yields:
A python csv reader.
317def find_new_name(taken: t.Collection[str], base: str) -> str: 318 """ 319 Searches for a new name. 320 321 Args: 322 taken: A collection of taken names. 323 base: Base name to alter. 324 325 Returns: 326 The new, available name. 327 """ 328 if base not in taken: 329 return base 330 331 i = 2 332 new = f"{base}_{i}" 333 while new in taken: 334 i += 1 335 new = f"{base}_{i}" 336 337 return new
Searches for a new name.
Arguments:
- taken: A collection of taken names.
- base: Base name to alter.
Returns:
The new, available name.
356def name_sequence(prefix: str) -> t.Callable[[], str]: 357 """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").""" 358 sequence = count() 359 return lambda: f"{prefix}{next(sequence)}"
Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").
362def object_to_dict(obj: t.Any, **kwargs) -> t.Dict: 363 """Returns a dictionary created from an object's attributes.""" 364 return { 365 **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()}, 366 **kwargs, 367 }
Returns a dictionary created from an object's attributes.
370def split_num_words( 371 value: str, sep: str, min_num_words: int, fill_from_start: bool = True 372) -> t.List[t.Optional[str]]: 373 """ 374 Perform a split on a value and return N words as a result with `None` used for words that don't exist. 375 376 Args: 377 value: The value to be split. 378 sep: The value to use to split on. 379 min_num_words: The minimum number of words that are going to be in the result. 380 fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list. 381 382 Examples: 383 >>> split_num_words("db.table", ".", 3) 384 [None, 'db', 'table'] 385 >>> split_num_words("db.table", ".", 3, fill_from_start=False) 386 ['db', 'table', None] 387 >>> split_num_words("db.table", ".", 1) 388 ['db', 'table'] 389 390 Returns: 391 The list of words returned by `split`, possibly augmented by a number of `None` values. 392 """ 393 words = value.split(sep) 394 if fill_from_start: 395 return [None] * (min_num_words - len(words)) + words 396 return words + [None] * (min_num_words - len(words))
Perform a split on a value and return N words as a result with None
used for words that don't exist.
Arguments:
- value: The value to be split.
- sep: The value to use to split on.
- min_num_words: The minimum number of words that are going to be in the result.
- fill_from_start: Indicates whether
None
values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3) [None, 'db', 'table'] >>> split_num_words("db.table", ".", 3, fill_from_start=False) ['db', 'table', None] >>> split_num_words("db.table", ".", 1) ['db', 'table']
Returns:
The list of words returned by
split
, possibly augmented by a number of None
values.
399def is_iterable(value: t.Any) -> bool: 400 """ 401 Checks if the value is an iterable, excluding the types `str` and `bytes`. 402 403 Examples: 404 >>> is_iterable([1,2]) 405 True 406 >>> is_iterable("test") 407 False 408 409 Args: 410 value: The value to check if it is an iterable. 411 412 Returns: 413 A `bool` value indicating if it is an iterable. 414 """ 415 from sqlglot import Expression 416 417 return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
Checks if the value is an iterable, excluding the types str
and bytes
.
Examples:
>>> is_iterable([1,2]) True >>> is_iterable("test") False
Arguments:
- value: The value to check if it is an iterable.
Returns:
A
bool
value indicating if it is an iterable.
420def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]: 421 """ 422 Flattens an iterable that can contain both iterable and non-iterable elements. Objects of 423 type `str` and `bytes` are not regarded as iterables. 424 425 Examples: 426 >>> list(flatten([[1, 2], 3, {4}, (5, "bla")])) 427 [1, 2, 3, 4, 5, 'bla'] 428 >>> list(flatten([1, 2, 3])) 429 [1, 2, 3] 430 431 Args: 432 values: The value to be flattened. 433 434 Yields: 435 Non-iterable elements in `values`. 436 """ 437 for value in values: 438 if is_iterable(value): 439 yield from flatten(value) 440 else: 441 yield value
Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
type str
and bytes
are not regarded as iterables.
Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")])) [1, 2, 3, 4, 5, 'bla'] >>> list(flatten([1, 2, 3])) [1, 2, 3]
Arguments:
- values: The value to be flattened.
Yields:
Non-iterable elements in
values
.
444def dict_depth(d: t.Dict) -> int: 445 """ 446 Get the nesting depth of a dictionary. 447 448 Example: 449 >>> dict_depth(None) 450 0 451 >>> dict_depth({}) 452 1 453 >>> dict_depth({"a": "b"}) 454 1 455 >>> dict_depth({"a": {}}) 456 2 457 >>> dict_depth({"a": {"b": {}}}) 458 3 459 """ 460 try: 461 return 1 + dict_depth(next(iter(d.values()))) 462 except AttributeError: 463 # d doesn't have attribute "values" 464 return 0 465 except StopIteration: 466 # d.values() returns an empty sequence 467 return 1
Get the nesting depth of a dictionary.
Example:
>>> dict_depth(None) 0 >>> dict_depth({}) 1 >>> dict_depth({"a": "b"}) 1 >>> dict_depth({"a": {}}) 2 >>> dict_depth({"a": {"b": {}}}) 3
470def first(it: t.Iterable[T]) -> T: 471 """Returns the first element from an iterable (useful for sets).""" 472 return next(i for i in it)
Returns the first element from an iterable (useful for sets).
475def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]: 476 if isinstance(value, bool) or value is None: 477 return value 478 479 # Coerce the value to boolean if it matches to the truthy/falsy values below 480 value_lower = value.lower() 481 if value_lower in ("true", "1"): 482 return True 483 if value_lower in ("false", "0"): 484 return False 485 486 return value
489def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]: 490 """ 491 Merges a sequence of ranges, represented as tuples (low, high) whose values 492 belong to some totally-ordered set. 493 494 Example: 495 >>> merge_ranges([(1, 3), (2, 6)]) 496 [(1, 6)] 497 """ 498 if not ranges: 499 return [] 500 501 ranges = sorted(ranges) 502 503 merged = [ranges[0]] 504 505 for start, end in ranges[1:]: 506 last_start, last_end = merged[-1] 507 508 if start <= last_end: 509 merged[-1] = (last_start, max(last_end, end)) 510 else: 511 merged.append((start, end)) 512 513 return merged
Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.
Example:
>>> merge_ranges([(1, 3), (2, 6)]) [(1, 6)]
544class SingleValuedMapping(t.Mapping[K, V]): 545 """ 546 Mapping where all keys return the same value. 547 548 This rigamarole is meant to avoid copying keys, which was originally intended 549 as an optimization while qualifying columns for tables with lots of columns. 550 """ 551 552 def __init__(self, keys: t.Collection[K], value: V): 553 self._keys = keys if isinstance(keys, Set) else set(keys) 554 self._value = value 555 556 def __getitem__(self, key: K) -> V: 557 if key in self._keys: 558 return self._value 559 raise KeyError(key) 560 561 def __len__(self) -> int: 562 return len(self._keys) 563 564 def __iter__(self) -> t.Iterator[K]: 565 return iter(self._keys)
Mapping where all keys return the same value.
This rigamarole is meant to avoid copying keys, which was originally intended as an optimization while qualifying columns for tables with lots of columns.