sqlglot.helper
from __future__ import annotations

import datetime
import inspect
import logging
import re
import sys
import typing as t
from collections.abc import Collection, Set
from contextlib import contextmanager
from copy import copy
from difflib import get_close_matches
from enum import Enum
from itertools import count

if t.TYPE_CHECKING:
    from sqlglot import exp
    from sqlglot._typing import A, E, T
    from sqlglot.dialects.dialect import DialectType
    from sqlglot.expressions import Expression


# Matches the zero-width position before each interior uppercase letter of a
# camelCase word, i.e. where an underscore should be inserted.
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")


class AutoName(Enum):
    """
    This is used for creating Enum classes where `auto()` is the string form
    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    def _generate_next_value_(name, _start, _count, _last_values):
        return name


class classproperty(property):
    """
    Similar to a normal property but works for class methods
    """

    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
        return classmethod(self.fget).__get__(None, owner)()  # type: ignore


def suggest_closest_match_and_fail(
    kind: str,
    word: str,
    possibilities: t.Iterable[str],
) -> None:
    """
    Raises a `ValueError` for an unknown `word`, suggesting the closest valid alternative.

    Args:
        kind: A human-readable description of what `word` is (e.g. "dialect").
        word: The unknown identifier that was supplied.
        possibilities: The collection of valid identifiers to match against.

    Raises:
        ValueError: Always; the message includes a suggestion when a close match exists.
    """
    close_matches = get_close_matches(word, possibilities, n=1)

    similar = seq_get(close_matches, 0) or ""
    if similar:
        similar = f" Did you mean {similar}?"

    raise ValueError(f"Unknown {kind} '{word}'.{similar}")


def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
    try:
        return seq[index]
    except IndexError:
        return None


@t.overload
def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
@t.overload
def ensure_list(value: None) -> t.List: ...


@t.overload
def ensure_list(value: T) -> t.List[T]: ...


def ensure_list(value):
    """
    Ensures that a value is a list, otherwise casts or wraps it into one.

    Args:
        value: The value of interest.

    Returns:
        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)

    return [value]


@t.overload
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...


@t.overload
def ensure_collection(value: T) -> t.Collection[T]: ...


def ensure_collection(value):
    """
    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.

    Args:
        value: The value of interest.

    Returns:
        The value if it's a collection, or else the value wrapped in a list.
    """
    if value is None:
        return []
    return (
        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
    )


def csv(*args: str, sep: str = ", ") -> str:
    """
    Formats any number of string arguments as CSV.

    Args:
        args: The string arguments to format.
        sep: The argument separator.

    Returns:
        The arguments formatted as a CSV string.
    """
    # Falsy (empty) arguments are dropped so we never emit dangling separators
    return sep.join(arg for arg in args if arg)


def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: The name of the module to search for subclasses in.
        classes: Class(es) we want to find the subclasses of.
        exclude: Class(es) we want to exclude from the returned list.

    Returns:
        The target subclasses.
    """
    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
        )
    ]


def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
    dialect: DialectType = None,
) -> t.List[E]:
    """
    Applies an offset to a given integer literal expression.

    Args:
        this: The target of the index.
        expressions: The expression the offset will be applied to, wrapped in a list.
        offset: The offset that will be applied.
        dialect: the dialect of interest.

    Returns:
        The original expression with the offset applied to it, wrapped in a list. If the provided
        `expressions` argument contains more than one expression, it's returned unaffected.
    """
    if not offset or len(expressions) != 1:
        return expressions

    expression = expressions[0]

    # Imported locally to avoid circular imports at module load time
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    if not this.type:
        annotate_types(this, dialect=dialect)

    # Only arrays (or targets of unknown type) are indexable; anything else is left untouched
    if t.cast(exp.DataType, this.type).this not in (
        exp.DataType.Type.UNKNOWN,
        exp.DataType.Type.ARRAY,
    ):
        return expressions

    if not expression.type:
        annotate_types(expression, dialect=dialect)

    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
        logger.info("Applying array index offset (%s)", offset)
        expression = simplify(expression + offset)
        return [expression]

    return expressions


def camel_to_snake_case(name: str) -> str:
    """Converts `name` from camelCase to UPPER_SNAKE_CASE (e.g. "fooBar" -> "FOO_BAR")."""
    return CAMEL_CASE_PATTERN.sub("_", name).upper()


def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Applies a transformation to a given expression until a fix point is reached.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied.

    Returns:
        The transformed expression.
    """
    end_hash: t.Optional[int] = None

    while True:
        # No need to walk the AST - we've already cached the hashes in the previous iteration
        if end_hash is None:
            for n in reversed(tuple(expression.walk())):
                n._hash = hash(n)

        start_hash = hash(expression)
        expression = func(expression)

        expression_nodes = tuple(expression.walk())

        # Uncache previous caches so we can recompute them
        for n in reversed(expression_nodes):
            n._hash = None
            n._hash = hash(n)

        end_hash = hash(expression)

        if start_hash == end_hash:
            # ... and reset the hash so we don't risk it becoming out of date if a mutation happens
            for n in expression_nodes:
                n._hash = None

            break

    return expression


def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Note: the input `dag` is consumed (mutated) in the process.

    Args:
        dag: The graph to be sorted.

    Returns:
        A list that contains all of the graph's nodes in topological order.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    result = []

    # Ensure every referenced node has an entry, even if it has no dependencies itself
    for node, deps in tuple(dag.items()):
        for dep in deps:
            if dep not in dag:
                dag[dep] = set()

    while dag:
        current = {node for node, deps in dag.items() if not deps}

        if not current:
            raise ValueError("Cycle error")

        for node in current:
            dag.pop(node)

        for deps in dag.values():
            deps -= current

        # Sorting makes the output deterministic for nodes at the same depth
        result.extend(sorted(current))  # type: ignore

    return result


def open_file(file_name: str) -> t.TextIO:
    """Open a file that may be compressed as gzip and return it in universal newline mode."""
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"  # gzip magic number

    if gzipped:
        import gzip

        # Explicit encoding keeps this path consistent with the plain-text branch below
        return gzip.open(file_name, "rt", newline="", encoding="utf-8")

    return open(file_name, encoding="utf-8", newline="")


@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader.
    """
    args = read_csv.expressions
    file = open_file(read_csv.name)

    delimiter = ","
    # The remaining args are a flat list of alternating option names and values
    args = iter(arg.name for arg in args)  # type: ignore
    for k, v in zip(args, args):
        if k == "delimiter":
            delimiter = v

    try:
        import csv as csv_

        yield csv_.reader(file, delimiter=delimiter)
    finally:
        file.close()


def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Searches for a new name.

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        The new, available name.
    """
    if base not in taken:
        return base

    i = 2
    new = f"{base}_{i}"
    while new in taken:
        i += 1
        new = f"{base}_{i}"

    return new


def is_int(text: str) -> bool:
    """Returns True if `text` parses as an `int`."""
    return is_type(text, int)


def is_float(text: str) -> bool:
    """Returns True if `text` parses as a `float`."""
    return is_type(text, float)


def is_type(text: str, target_type: t.Type) -> bool:
    """Returns True if `target_type(text)` succeeds without raising `ValueError`."""
    try:
        target_type(text)
        return True
    except ValueError:
        return False


def name_sequence(prefix: str) -> t.Callable[[], str]:
    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
    sequence = count()
    return lambda: f"{prefix}{next(sequence)}"


def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """Returns a dictionary created from an object's attributes."""
    return {
        # Shallow-copy each attribute so mutating the dict values doesn't affect `obj`
        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
        **kwargs,
    }


def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Perform a split on a value and return N words as a result with `None` used for words that don't exist.

    Args:
        value: The value to be split.
        sep: The value to use to split on.
        min_num_words: The minimum number of words that are going to be in the result.
        fill_from_start: Indicates whether `None` values should be inserted at the start or end of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly augmented by a number of `None` values.
    """
    words = value.split(sep)
    if fill_from_start:
        return [None] * (min_num_words - len(words)) + words
    return words + [None] * (min_num_words - len(words))


def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is an iterable, excluding the types `str` and `bytes`.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    from sqlglot import Expression

    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))


def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
    type `str` and `bytes` are not regarded as iterables.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`.
    """
    for value in values:
        if is_iterable(value):
            yield from flatten(value)
        else:
            yield value


def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        return 1 + dict_depth(next(iter(d.values())))
    except AttributeError:
        # d doesn't have attribute "values"
        return 0
    except StopIteration:
        # d.values() returns an empty sequence
        return 1


def first(it: t.Iterable[T]) -> T:
    """Returns the first element from an iterable (useful for sets)."""
    return next(iter(it))


def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
    """Coerces `value` to a boolean if it matches a truthy/falsy string, else returns it unchanged."""
    if isinstance(value, bool) or value is None:
        return value

    # Coerce the value to boolean if it matches to the truthy/falsy values below
    value_lower = value.lower()
    if value_lower in ("true", "1"):
        return True
    if value_lower in ("false", "0"):
        return False

    return value


def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Merges a sequence of ranges, represented as tuples (low, high) whose values
    belong to some totally-ordered set.

    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    ranges = sorted(ranges)

    merged = [ranges[0]]

    for start, end in ranges[1:]:
        last_start, last_end = merged[-1]

        if start <= last_end:
            # Overlapping or touching ranges are collapsed into one
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))

    return merged


def is_iso_date(text: str) -> bool:
    """Returns True if `text` is a valid ISO-8601 date (e.g. "2023-01-02")."""
    try:
        datetime.date.fromisoformat(text)
        return True
    except ValueError:
        return False


def is_iso_datetime(text: str) -> bool:
    """Returns True if `text` is a valid ISO-8601 datetime (e.g. "2023-01-02T03:04:05")."""
    try:
        datetime.datetime.fromisoformat(text)
        return True
    except ValueError:
        return False


# Interval units that operate on date components
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}


def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
    """Returns True if `expression` names an interval unit that operates on date components."""
    return expression is not None and expression.name.lower() in DATE_UNITS


K = t.TypeVar("K")
V = t.TypeVar("V")


class SingleValuedMapping(t.Mapping[K, V]):
    """
    Mapping where all keys return the same value.

    This rigamarole is meant to avoid copying keys, which was originally intended
    as an optimization while qualifying columns for tables with lots of columns.
    """

    def __init__(self, keys: t.Collection[K], value: V):
        # Reuse the keys as-is when they're already a set to avoid copying them
        self._keys = keys if isinstance(keys, Set) else set(keys)
        self._value = value

    def __getitem__(self, key: K) -> V:
        if key in self._keys:
            return self._value
        raise KeyError(key)

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> t.Iterator[K]:
        return iter(self._keys)
29class AutoName(Enum): 30 """ 31 This is used for creating Enum classes where `auto()` is the string form 32 of the corresponding enum's identifier (e.g. FOO.value results in "FOO"). 33 34 Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values 35 """ 36 37 def _generate_next_value_(name, _start, _count, _last_values): 38 return name
This is used for creating Enum classes where `auto()` is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
41class classproperty(property): 42 """ 43 Similar to a normal property but works for class methods 44 """ 45 46 def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any: 47 return classmethod(self.fget).__get__(None, owner)() # type: ignore
Similar to a normal property but works for class methods
50def suggest_closest_match_and_fail( 51 kind: str, 52 word: str, 53 possibilities: t.Iterable[str], 54) -> None: 55 close_matches = get_close_matches(word, possibilities, n=1) 56 57 similar = seq_get(close_matches, 0) or "" 58 if similar: 59 similar = f" Did you mean {similar}?" 60 61 raise ValueError(f"Unknown {kind} '{word}'.{similar}")
64def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]: 65 """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds.""" 66 try: 67 return seq[index] 68 except IndexError: 69 return None
Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds.
84def ensure_list(value): 85 """ 86 Ensures that a value is a list, otherwise casts or wraps it into one. 87 88 Args: 89 value: The value of interest. 90 91 Returns: 92 The value cast as a list if it's a list or a tuple, or else the value wrapped in a list. 93 """ 94 if value is None: 95 return [] 96 if isinstance(value, (list, tuple)): 97 return list(value) 98 99 return [value]
Ensures that a value is a list, otherwise casts or wraps it into one.
Arguments:
- value: The value of interest.
Returns:
The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
110def ensure_collection(value): 111 """ 112 Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list. 113 114 Args: 115 value: The value of interest. 116 117 Returns: 118 The value if it's a collection, or else the value wrapped in a list. 119 """ 120 if value is None: 121 return [] 122 return ( 123 value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value] 124 )
Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
Arguments:
- value: The value of interest.
Returns:
The value if it's a collection, or else the value wrapped in a list.
127def csv(*args: str, sep: str = ", ") -> str: 128 """ 129 Formats any number of string arguments as CSV. 130 131 Args: 132 args: The string arguments to format. 133 sep: The argument separator. 134 135 Returns: 136 The arguments formatted as a CSV string. 137 """ 138 return sep.join(arg for arg in args if arg)
Formats any number of string arguments as CSV.
Arguments:
- args: The string arguments to format.
- sep: The argument separator.
Returns:
The arguments formatted as a CSV string.
141def subclasses( 142 module_name: str, 143 classes: t.Type | t.Tuple[t.Type, ...], 144 exclude: t.Type | t.Tuple[t.Type, ...] = (), 145) -> t.List[t.Type]: 146 """ 147 Returns all subclasses for a collection of classes, possibly excluding some of them. 148 149 Args: 150 module_name: The name of the module to search for subclasses in. 151 classes: Class(es) we want to find the subclasses of. 152 exclude: Class(es) we want to exclude from the returned list. 153 154 Returns: 155 The target subclasses. 156 """ 157 return [ 158 obj 159 for _, obj in inspect.getmembers( 160 sys.modules[module_name], 161 lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude, 162 ) 163 ]
Returns all subclasses for a collection of classes, possibly excluding some of them.
Arguments:
- module_name: The name of the module to search for subclasses in.
- classes: Class(es) we want to find the subclasses of.
- exclude: Class(es) we want to exclude from the returned list.
Returns:
The target subclasses.
166def apply_index_offset( 167 this: exp.Expression, 168 expressions: t.List[E], 169 offset: int, 170 dialect: DialectType = None, 171) -> t.List[E]: 172 """ 173 Applies an offset to a given integer literal expression. 174 175 Args: 176 this: The target of the index. 177 expressions: The expression the offset will be applied to, wrapped in a list. 178 offset: The offset that will be applied. 179 dialect: the dialect of interest. 180 181 Returns: 182 The original expression with the offset applied to it, wrapped in a list. If the provided 183 `expressions` argument contains more than one expression, it's returned unaffected. 184 """ 185 if not offset or len(expressions) != 1: 186 return expressions 187 188 expression = expressions[0] 189 190 from sqlglot import exp 191 from sqlglot.optimizer.annotate_types import annotate_types 192 from sqlglot.optimizer.simplify import simplify 193 194 if not this.type: 195 annotate_types(this, dialect=dialect) 196 197 if t.cast(exp.DataType, this.type).this not in ( 198 exp.DataType.Type.UNKNOWN, 199 exp.DataType.Type.ARRAY, 200 ): 201 return expressions 202 203 if not expression.type: 204 annotate_types(expression, dialect=dialect) 205 206 if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES: 207 logger.info("Applying array index offset (%s)", offset) 208 expression = simplify(expression + offset) 209 return [expression] 210 211 return expressions
Applies an offset to a given integer literal expression.
Arguments:
- this: The target of the index.
- expressions: The expression the offset will be applied to, wrapped in a list.
- offset: The offset that will be applied.
- dialect: the dialect of interest.
Returns:
The original expression with the offset applied to it, wrapped in a list. If the provided `expressions` argument contains more than one expression, it's returned unaffected.
214def camel_to_snake_case(name: str) -> str: 215 """Converts `name` from camelCase to snake_case and returns the result.""" 216 return CAMEL_CASE_PATTERN.sub("_", name).upper()
Converts name
from camelCase to snake_case and returns the result.
219def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E: 220 """ 221 Applies a transformation to a given expression until a fix point is reached. 222 223 Args: 224 expression: The expression to be transformed. 225 func: The transformation to be applied. 226 227 Returns: 228 The transformed expression. 229 """ 230 end_hash: t.Optional[int] = None 231 232 while True: 233 # No need to walk the AST– we've already cached the hashes in the previous iteration 234 if end_hash is None: 235 for n in reversed(tuple(expression.walk())): 236 n._hash = hash(n) 237 238 start_hash = hash(expression) 239 expression = func(expression) 240 241 expression_nodes = tuple(expression.walk()) 242 243 # Uncache previous caches so we can recompute them 244 for n in reversed(expression_nodes): 245 n._hash = None 246 n._hash = hash(n) 247 248 end_hash = hash(expression) 249 250 if start_hash == end_hash: 251 # ... and reset the hash so we don't risk it becoming out of date if a mutation happens 252 for n in expression_nodes: 253 n._hash = None 254 255 break 256 257 return expression
Applies a transformation to a given expression until a fix point is reached.
Arguments:
- expression: The expression to be transformed.
- func: The transformation to be applied.
Returns:
The transformed expression.
260def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]: 261 """ 262 Sorts a given directed acyclic graph in topological order. 263 264 Args: 265 dag: The graph to be sorted. 266 267 Returns: 268 A list that contains all of the graph's nodes in topological order. 269 """ 270 result = [] 271 272 for node, deps in tuple(dag.items()): 273 for dep in deps: 274 if dep not in dag: 275 dag[dep] = set() 276 277 while dag: 278 current = {node for node, deps in dag.items() if not deps} 279 280 if not current: 281 raise ValueError("Cycle error") 282 283 for node in current: 284 dag.pop(node) 285 286 for deps in dag.values(): 287 deps -= current 288 289 result.extend(sorted(current)) # type: ignore 290 291 return result
Sorts a given directed acyclic graph in topological order.
Arguments:
- dag: The graph to be sorted.
Returns:
A list that contains all of the graph's nodes in topological order.
294def open_file(file_name: str) -> t.TextIO: 295 """Open a file that may be compressed as gzip and return it in universal newline mode.""" 296 with open(file_name, "rb") as f: 297 gzipped = f.read(2) == b"\x1f\x8b" 298 299 if gzipped: 300 import gzip 301 302 return gzip.open(file_name, "rt", newline="") 303 304 return open(file_name, encoding="utf-8", newline="")
Open a file that may be compressed as gzip and return it in universal newline mode.
307@contextmanager 308def csv_reader(read_csv: exp.ReadCSV) -> t.Any: 309 """ 310 Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`. 311 312 Args: 313 read_csv: A `ReadCSV` function call. 314 315 Yields: 316 A python csv reader. 317 """ 318 args = read_csv.expressions 319 file = open_file(read_csv.name) 320 321 delimiter = "," 322 args = iter(arg.name for arg in args) # type: ignore 323 for k, v in zip(args, args): 324 if k == "delimiter": 325 delimiter = v 326 327 try: 328 import csv as csv_ 329 330 yield csv_.reader(file, delimiter=delimiter) 331 finally: 332 file.close()
Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
Arguments:
- read_csv: A `ReadCSV` function call.
Yields:
A python csv reader.
335def find_new_name(taken: t.Collection[str], base: str) -> str: 336 """ 337 Searches for a new name. 338 339 Args: 340 taken: A collection of taken names. 341 base: Base name to alter. 342 343 Returns: 344 The new, available name. 345 """ 346 if base not in taken: 347 return base 348 349 i = 2 350 new = f"{base}_{i}" 351 while new in taken: 352 i += 1 353 new = f"{base}_{i}" 354 355 return new
Searches for a new name.
Arguments:
- taken: A collection of taken names.
- base: Base name to alter.
Returns:
The new, available name.
374def name_sequence(prefix: str) -> t.Callable[[], str]: 375 """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").""" 376 sequence = count() 377 return lambda: f"{prefix}{next(sequence)}"
Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").
380def object_to_dict(obj: t.Any, **kwargs) -> t.Dict: 381 """Returns a dictionary created from an object's attributes.""" 382 return { 383 **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()}, 384 **kwargs, 385 }
Returns a dictionary created from an object's attributes.
388def split_num_words( 389 value: str, sep: str, min_num_words: int, fill_from_start: bool = True 390) -> t.List[t.Optional[str]]: 391 """ 392 Perform a split on a value and return N words as a result with `None` used for words that don't exist. 393 394 Args: 395 value: The value to be split. 396 sep: The value to use to split on. 397 min_num_words: The minimum number of words that are going to be in the result. 398 fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list. 399 400 Examples: 401 >>> split_num_words("db.table", ".", 3) 402 [None, 'db', 'table'] 403 >>> split_num_words("db.table", ".", 3, fill_from_start=False) 404 ['db', 'table', None] 405 >>> split_num_words("db.table", ".", 1) 406 ['db', 'table'] 407 408 Returns: 409 The list of words returned by `split`, possibly augmented by a number of `None` values. 410 """ 411 words = value.split(sep) 412 if fill_from_start: 413 return [None] * (min_num_words - len(words)) + words 414 return words + [None] * (min_num_words - len(words))
Perform a split on a value and return N words as a result with `None` used for words that don't exist.
Arguments:
- value: The value to be split.
- sep: The value to use to split on.
- min_num_words: The minimum number of words that are going to be in the result.
- fill_from_start: Indicates whether `None` values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3) [None, 'db', 'table'] >>> split_num_words("db.table", ".", 3, fill_from_start=False) ['db', 'table', None] >>> split_num_words("db.table", ".", 1) ['db', 'table']
Returns:
The list of words returned by `split`, possibly augmented by a number of `None` values.
417def is_iterable(value: t.Any) -> bool: 418 """ 419 Checks if the value is an iterable, excluding the types `str` and `bytes`. 420 421 Examples: 422 >>> is_iterable([1,2]) 423 True 424 >>> is_iterable("test") 425 False 426 427 Args: 428 value: The value to check if it is an iterable. 429 430 Returns: 431 A `bool` value indicating if it is an iterable. 432 """ 433 from sqlglot import Expression 434 435 return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
Checks if the value is an iterable, excluding the types `str` and `bytes`.
Examples:
>>> is_iterable([1,2]) True >>> is_iterable("test") False
Arguments:
- value: The value to check if it is an iterable.
Returns:
A
bool
value indicating if it is an iterable.
438def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]: 439 """ 440 Flattens an iterable that can contain both iterable and non-iterable elements. Objects of 441 type `str` and `bytes` are not regarded as iterables. 442 443 Examples: 444 >>> list(flatten([[1, 2], 3, {4}, (5, "bla")])) 445 [1, 2, 3, 4, 5, 'bla'] 446 >>> list(flatten([1, 2, 3])) 447 [1, 2, 3] 448 449 Args: 450 values: The value to be flattened. 451 452 Yields: 453 Non-iterable elements in `values`. 454 """ 455 for value in values: 456 if is_iterable(value): 457 yield from flatten(value) 458 else: 459 yield value
Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type `str` and `bytes` are not regarded as iterables.
Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")])) [1, 2, 3, 4, 5, 'bla'] >>> list(flatten([1, 2, 3])) [1, 2, 3]
Arguments:
- values: The value to be flattened.
Yields:
Non-iterable elements in `values`.
462def dict_depth(d: t.Dict) -> int: 463 """ 464 Get the nesting depth of a dictionary. 465 466 Example: 467 >>> dict_depth(None) 468 0 469 >>> dict_depth({}) 470 1 471 >>> dict_depth({"a": "b"}) 472 1 473 >>> dict_depth({"a": {}}) 474 2 475 >>> dict_depth({"a": {"b": {}}}) 476 3 477 """ 478 try: 479 return 1 + dict_depth(next(iter(d.values()))) 480 except AttributeError: 481 # d doesn't have attribute "values" 482 return 0 483 except StopIteration: 484 # d.values() returns an empty sequence 485 return 1
Get the nesting depth of a dictionary.
Example:
>>> dict_depth(None) 0 >>> dict_depth({}) 1 >>> dict_depth({"a": "b"}) 1 >>> dict_depth({"a": {}}) 2 >>> dict_depth({"a": {"b": {}}}) 3
488def first(it: t.Iterable[T]) -> T: 489 """Returns the first element from an iterable (useful for sets).""" 490 return next(i for i in it)
Returns the first element from an iterable (useful for sets).
493def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]: 494 if isinstance(value, bool) or value is None: 495 return value 496 497 # Coerce the value to boolean if it matches to the truthy/falsy values below 498 value_lower = value.lower() 499 if value_lower in ("true", "1"): 500 return True 501 if value_lower in ("false", "0"): 502 return False 503 504 return value
507def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]: 508 """ 509 Merges a sequence of ranges, represented as tuples (low, high) whose values 510 belong to some totally-ordered set. 511 512 Example: 513 >>> merge_ranges([(1, 3), (2, 6)]) 514 [(1, 6)] 515 """ 516 if not ranges: 517 return [] 518 519 ranges = sorted(ranges) 520 521 merged = [ranges[0]] 522 523 for start, end in ranges[1:]: 524 last_start, last_end = merged[-1] 525 526 if start <= last_end: 527 merged[-1] = (last_start, max(last_end, end)) 528 else: 529 merged.append((start, end)) 530 531 return merged
Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.
Example:
>>> merge_ranges([(1, 3), (2, 6)]) [(1, 6)]
562class SingleValuedMapping(t.Mapping[K, V]): 563 """ 564 Mapping where all keys return the same value. 565 566 This rigamarole is meant to avoid copying keys, which was originally intended 567 as an optimization while qualifying columns for tables with lots of columns. 568 """ 569 570 def __init__(self, keys: t.Collection[K], value: V): 571 self._keys = keys if isinstance(keys, Set) else set(keys) 572 self._value = value 573 574 def __getitem__(self, key: K) -> V: 575 if key in self._keys: 576 return self._value 577 raise KeyError(key) 578 579 def __len__(self) -> int: 580 return len(self._keys) 581 582 def __iter__(self) -> t.Iterator[K]: 583 return iter(self._keys)
Mapping where all keys return the same value.
This rigamarole is meant to avoid copying keys, which was originally intended as an optimization while qualifying columns for tables with lots of columns.