sqlglot.helper
from __future__ import annotations

import datetime
import inspect
import logging
import re
import sys
import typing as t
from collections.abc import Collection, Set
from contextlib import contextmanager
from copy import copy
from enum import Enum
from itertools import count

if t.TYPE_CHECKING:
    from sqlglot import exp
    from sqlglot._typing import A, E, T
    from sqlglot.dialects.dialect import DialectType
    from sqlglot.expressions import Expression


# Matches the (empty) position before each uppercase letter, except at the start of
# the string, so substituting "_" splits camelCase identifiers into words.
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
# (major, minor) of the running interpreter, e.g. (3, 11).
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")


class AutoName(Enum):
    """
    This is used for creating Enum classes where `auto()` is the string form
    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    # Enum hook: the "auto()" value for a member is the member's own name.
    def _generate_next_value_(name, _start, _count, _last_values):
        return name


class classproperty(property):
    """
    Similar to a normal property but works for class methods
    """

    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
        # Bind the wrapped getter as a classmethod of `owner` and call it immediately.
        return classmethod(self.fget).__get__(None, owner)()  # type: ignore


def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
    try:
        return seq[index]
    except IndexError:
        return None


@t.overload
def ensure_list(value: t.Collection[T]) -> t.List[T]: ...


@t.overload
def ensure_list(value: None) -> t.List: ...


@t.overload
def ensure_list(value: T) -> t.List[T]: ...


def ensure_list(value):
    """
    Ensures that a value is a list, otherwise casts or wraps it into one.

    Args:
        value: The value of interest.

    Returns:
        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)

    return [value]


@t.overload
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...


@t.overload
def ensure_collection(value: T) -> t.Collection[T]: ...


def ensure_collection(value):
    """
    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.

    Args:
        value: The value of interest.

    Returns:
        The value if it's a collection, or else the value wrapped in a list.
    """
    if value is None:
        return []
    return (
        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
    )


def csv(*args: str, sep: str = ", ") -> str:
    """
    Formats any number of string arguments as CSV.

    Args:
        args: The string arguments to format. Falsy (e.g. empty) arguments are skipped.
        sep: The argument separator.

    Returns:
        The arguments formatted as a CSV string.
    """
    return sep.join(arg for arg in args if arg)


def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: The name of the module to search for subclasses in.
        classes: Class(es) we want to find the subclasses of.
        exclude: Class(es) we want to exclude from the returned list.

    Returns:
        The target subclasses.
    """
    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
        )
    ]


def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
    dialect: DialectType = None,
) -> t.List[E]:
    """
    Applies an offset to a given integer literal expression.

    Args:
        this: The target of the index.
        expressions: The expression the offset will be applied to, wrapped in a list.
        offset: The offset that will be applied.
        dialect: the dialect of interest.

    Returns:
        The original expression with the offset applied to it, wrapped in a list. If the provided
        `expressions` argument contains more than one expression, it's returned unaffected.
    """
    if not offset or len(expressions) != 1:
        return expressions

    expression = expressions[0]

    # Imported lazily to avoid circular imports at module load time.
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    if not this.type:
        annotate_types(this, dialect=dialect)

    # Only apply the offset when indexing into an array (or an unknown type).
    if t.cast(exp.DataType, this.type).this not in (
        exp.DataType.Type.UNKNOWN,
        exp.DataType.Type.ARRAY,
    ):
        return expressions

    if not expression.type:
        annotate_types(expression, dialect=dialect)

    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
        logger.info("Applying array index offset (%s)", offset)
        expression = simplify(expression + offset)
        return [expression]

    return expressions


def camel_to_snake_case(name: str) -> str:
    """Converts `name` from camelCase to snake_case and returns the result."""
    # Note: the result is upper-cased, e.g. "fooBar" -> "FOO_BAR".
    return CAMEL_CASE_PATTERN.sub("_", name).upper()


def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Applies a transformation to a given expression until a fix point is reached.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied.

    Returns:
        The transformed expression.
    """
    end_hash: t.Optional[int] = None

    while True:
        # No need to walk the AST– we've already cached the hashes in the previous iteration
        if end_hash is None:
            for n in reversed(tuple(expression.walk())):
                n._hash = hash(n)

        start_hash = hash(expression)
        expression = func(expression)

        expression_nodes = tuple(expression.walk())

        # Uncache previous caches so we can recompute them
        for n in reversed(expression_nodes):
            n._hash = None
            n._hash = hash(n)

        end_hash = hash(expression)

        if start_hash == end_hash:
            # ... and reset the hash so we don't risk it becoming out of date if a mutation happens
            for n in expression_nodes:
                n._hash = None

            break

    return expression


def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    NOTE: `dag` is mutated in place (nodes are added for missing dependencies and
    then consumed), so callers should pass a copy if they need to keep the graph.

    Args:
        dag: The graph to be sorted.

    Returns:
        A list that contains all of the graph's nodes in topological order.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    result = []

    # Ensure every dependency appears as a node, even if it has no entry of its own.
    for node, deps in tuple(dag.items()):
        for dep in deps:
            if dep not in dag:
                dag[dep] = set()

    while dag:
        # Nodes with no remaining dependencies can be emitted next.
        current = {node for node, deps in dag.items() if not deps}

        if not current:
            raise ValueError("Cycle error")

        for node in current:
            dag.pop(node)

        for deps in dag.values():
            deps -= current

        # Sorting makes the output deterministic among same-level nodes.
        result.extend(sorted(current))  # type: ignore

    return result


def open_file(file_name: str) -> t.TextIO:
    """Open a file that may be compressed as gzip and return it in universal newline mode."""
    with open(file_name, "rb") as f:
        # gzip files start with the two magic bytes 0x1f 0x8b.
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        return gzip.open(file_name, "rt", newline="")

    return open(file_name, encoding="utf-8", newline="")


@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader.
    """
    args = read_csv.expressions
    file = open_file(read_csv.name)

    delimiter = ","
    # The remaining args come in flat ("key", "value") pairs; zipping an iterator
    # with itself consumes them two at a time.
    args = iter(arg.name for arg in args)  # type: ignore
    for k, v in zip(args, args):
        if k == "delimiter":
            delimiter = v

    try:
        import csv as csv_

        yield csv_.reader(file, delimiter=delimiter)
    finally:
        file.close()


def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Searches for a new name.

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        The new, available name (either `base` itself or `base_<i>` for the first free i >= 2).
    """
    if base not in taken:
        return base

    i = 2
    new = f"{base}_{i}"
    while new in taken:
        i += 1
        new = f"{base}_{i}"

    return new


def is_int(text: str) -> bool:
    """Returns True if `text` can be parsed as an int."""
    return is_type(text, int)


def is_float(text: str) -> bool:
    """Returns True if `text` can be parsed as a float."""
    return is_type(text, float)


def is_type(text: str, target_type: t.Type) -> bool:
    """Returns True if `target_type(text)` succeeds, False on ValueError."""
    try:
        target_type(text)
        return True
    except ValueError:
        # Only ValueError is caught; non-string inputs that raise TypeError will propagate.
        return False


def name_sequence(prefix: str) -> t.Callable[[], str]:
    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
    sequence = count()
    return lambda: f"{prefix}{next(sequence)}"


def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """Returns a dictionary created from an object's attributes."""
    # Each attribute value is (shallow-)copied so mutating the dict's values does not
    # affect the original object; **kwargs entries override attribute entries.
    return {
        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
        **kwargs,
    }


def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Perform a split on a value and return N words as a result with `None` used for words that don't exist.

    Args:
        value: The value to be split.
        sep: The value to use to split on.
        min_num_words: The minimum number of words that are going to be in the result.
        fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly augmented by a number of `None` values.
    """
    words = value.split(sep)
    if fill_from_start:
        return [None] * (min_num_words - len(words)) + words
    return words + [None] * (min_num_words - len(words))


def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is an iterable, excluding the types `str` and `bytes`.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    # Imported lazily to avoid a circular import; Expression is excluded because it is
    # iterable but should be treated as a scalar by flatten().
    from sqlglot import Expression

    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))


def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
    type `str` and `bytes` are not regarded as iterables.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`.
    """
    for value in values:
        if is_iterable(value):
            yield from flatten(value)
        else:
            yield value


def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        # Only the first value is inspected, so the dict is assumed to be uniformly nested.
        return 1 + dict_depth(next(iter(d.values())))
    except AttributeError:
        # d doesn't have attribute "values"
        return 0
    except StopIteration:
        # d.values() returns an empty sequence
        return 1


def first(it: t.Iterable[T]) -> T:
    """Returns the first element from an iterable (useful for sets)."""
    return next(i for i in it)


def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
    """Coerces `value` to a bool if it is a recognized truthy/falsy string, else returns it unchanged."""
    if isinstance(value, bool) or value is None:
        return value

    # Coerce the value to boolean if it matches to the truthy/falsy values below
    value_lower = value.lower()
    if value_lower in ("true", "1"):
        return True
    if value_lower in ("false", "0"):
        return False

    return value


def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Merges a sequence of ranges, represented as tuples (low, high) whose values
    belong to some totally-ordered set.

    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    # Sorting by (low, high) ensures overlapping ranges end up adjacent.
    ranges = sorted(ranges)

    merged = [ranges[0]]

    for start, end in ranges[1:]:
        last_start, last_end = merged[-1]

        if start <= last_end:
            # Ranges overlap or touch: extend the last merged range.
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))

    return merged


def is_iso_date(text: str) -> bool:
    """Returns True if `text` is a valid ISO-8601 date (e.g. "2024-01-31")."""
    try:
        datetime.date.fromisoformat(text)
        return True
    except ValueError:
        return False


def is_iso_datetime(text: str) -> bool:
    """Returns True if `text` is a valid ISO-8601 datetime (e.g. "2024-01-31T12:00:00")."""
    try:
        datetime.datetime.fromisoformat(text)
        return True
    except ValueError:
        return False


# Interval units that operate on date components
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}


def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
    """Returns True if `expression` names a date-component interval unit (case-insensitive)."""
    return expression is not None and expression.name.lower() in DATE_UNITS


K = t.TypeVar("K")
V = t.TypeVar("V")


class SingleValuedMapping(t.Mapping[K, V]):
    """
    Mapping where all keys return the same value.

    This rigamarole is meant to avoid copying keys, which was originally intended
    as an optimization while qualifying columns for tables with lots of columns.
    """

    def __init__(self, keys: t.Collection[K], value: V):
        # Reuse the caller's set when possible to avoid copying the keys.
        self._keys = keys if isinstance(keys, Set) else set(keys)
        self._value = value

    def __getitem__(self, key: K) -> V:
        if key in self._keys:
            return self._value
        raise KeyError(key)

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> t.Iterator[K]:
        return iter(self._keys)
28class AutoName(Enum): 29 """ 30 This is used for creating Enum classes where `auto()` is the string form 31 of the corresponding enum's identifier (e.g. FOO.value results in "FOO"). 32 33 Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values 34 """ 35 36 def _generate_next_value_(name, _start, _count, _last_values): 37 return name
This is used for creating Enum classes where auto()
is the string form
of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
40class classproperty(property): 41 """ 42 Similar to a normal property but works for class methods 43 """ 44 45 def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any: 46 return classmethod(self.fget).__get__(None, owner)() # type: ignore
Similar to a normal property but works for class methods
49def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]: 50 """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds.""" 51 try: 52 return seq[index] 53 except IndexError: 54 return None
Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds.
69def ensure_list(value): 70 """ 71 Ensures that a value is a list, otherwise casts or wraps it into one. 72 73 Args: 74 value: The value of interest. 75 76 Returns: 77 The value cast as a list if it's a list or a tuple, or else the value wrapped in a list. 78 """ 79 if value is None: 80 return [] 81 if isinstance(value, (list, tuple)): 82 return list(value) 83 84 return [value]
Ensures that a value is a list, otherwise casts or wraps it into one.
Arguments:
- value: The value of interest.
Returns:
The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
95def ensure_collection(value): 96 """ 97 Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list. 98 99 Args: 100 value: The value of interest. 101 102 Returns: 103 The value if it's a collection, or else the value wrapped in a list. 104 """ 105 if value is None: 106 return [] 107 return ( 108 value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value] 109 )
Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
Arguments:
- value: The value of interest.
Returns:
The value if it's a collection, or else the value wrapped in a list.
112def csv(*args: str, sep: str = ", ") -> str: 113 """ 114 Formats any number of string arguments as CSV. 115 116 Args: 117 args: The string arguments to format. 118 sep: The argument separator. 119 120 Returns: 121 The arguments formatted as a CSV string. 122 """ 123 return sep.join(arg for arg in args if arg)
Formats any number of string arguments as CSV.
Arguments:
- args: The string arguments to format.
- sep: The argument separator.
Returns:
The arguments formatted as a CSV string.
126def subclasses( 127 module_name: str, 128 classes: t.Type | t.Tuple[t.Type, ...], 129 exclude: t.Type | t.Tuple[t.Type, ...] = (), 130) -> t.List[t.Type]: 131 """ 132 Returns all subclasses for a collection of classes, possibly excluding some of them. 133 134 Args: 135 module_name: The name of the module to search for subclasses in. 136 classes: Class(es) we want to find the subclasses of. 137 exclude: Class(es) we want to exclude from the returned list. 138 139 Returns: 140 The target subclasses. 141 """ 142 return [ 143 obj 144 for _, obj in inspect.getmembers( 145 sys.modules[module_name], 146 lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude, 147 ) 148 ]
Returns all subclasses for a collection of classes, possibly excluding some of them.
Arguments:
- module_name: The name of the module to search for subclasses in.
- classes: Class(es) we want to find the subclasses of.
- exclude: Class(es) we want to exclude from the returned list.
Returns:
The target subclasses.
151def apply_index_offset( 152 this: exp.Expression, 153 expressions: t.List[E], 154 offset: int, 155 dialect: DialectType = None, 156) -> t.List[E]: 157 """ 158 Applies an offset to a given integer literal expression. 159 160 Args: 161 this: The target of the index. 162 expressions: The expression the offset will be applied to, wrapped in a list. 163 offset: The offset that will be applied. 164 dialect: the dialect of interest. 165 166 Returns: 167 The original expression with the offset applied to it, wrapped in a list. If the provided 168 `expressions` argument contains more than one expression, it's returned unaffected. 169 """ 170 if not offset or len(expressions) != 1: 171 return expressions 172 173 expression = expressions[0] 174 175 from sqlglot import exp 176 from sqlglot.optimizer.annotate_types import annotate_types 177 from sqlglot.optimizer.simplify import simplify 178 179 if not this.type: 180 annotate_types(this, dialect=dialect) 181 182 if t.cast(exp.DataType, this.type).this not in ( 183 exp.DataType.Type.UNKNOWN, 184 exp.DataType.Type.ARRAY, 185 ): 186 return expressions 187 188 if not expression.type: 189 annotate_types(expression, dialect=dialect) 190 191 if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES: 192 logger.info("Applying array index offset (%s)", offset) 193 expression = simplify(expression + offset) 194 return [expression] 195 196 return expressions
Applies an offset to a given integer literal expression.
Arguments:
- this: The target of the index.
- expressions: The expression the offset will be applied to, wrapped in a list.
- offset: The offset that will be applied.
- dialect: the dialect of interest.
Returns:
The original expression with the offset applied to it, wrapped in a list. If the provided `expressions` argument contains more than one expression, it's returned unaffected.
199def camel_to_snake_case(name: str) -> str: 200 """Converts `name` from camelCase to snake_case and returns the result.""" 201 return CAMEL_CASE_PATTERN.sub("_", name).upper()
Converts `name` from camelCase to snake_case and returns the result.
204def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E: 205 """ 206 Applies a transformation to a given expression until a fix point is reached. 207 208 Args: 209 expression: The expression to be transformed. 210 func: The transformation to be applied. 211 212 Returns: 213 The transformed expression. 214 """ 215 end_hash: t.Optional[int] = None 216 217 while True: 218 # No need to walk the AST– we've already cached the hashes in the previous iteration 219 if end_hash is None: 220 for n in reversed(tuple(expression.walk())): 221 n._hash = hash(n) 222 223 start_hash = hash(expression) 224 expression = func(expression) 225 226 expression_nodes = tuple(expression.walk()) 227 228 # Uncache previous caches so we can recompute them 229 for n in reversed(expression_nodes): 230 n._hash = None 231 n._hash = hash(n) 232 233 end_hash = hash(expression) 234 235 if start_hash == end_hash: 236 # ... and reset the hash so we don't risk it becoming out of date if a mutation happens 237 for n in expression_nodes: 238 n._hash = None 239 240 break 241 242 return expression
Applies a transformation to a given expression until a fix point is reached.
Arguments:
- expression: The expression to be transformed.
- func: The transformation to be applied.
Returns:
The transformed expression.
245def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]: 246 """ 247 Sorts a given directed acyclic graph in topological order. 248 249 Args: 250 dag: The graph to be sorted. 251 252 Returns: 253 A list that contains all of the graph's nodes in topological order. 254 """ 255 result = [] 256 257 for node, deps in tuple(dag.items()): 258 for dep in deps: 259 if dep not in dag: 260 dag[dep] = set() 261 262 while dag: 263 current = {node for node, deps in dag.items() if not deps} 264 265 if not current: 266 raise ValueError("Cycle error") 267 268 for node in current: 269 dag.pop(node) 270 271 for deps in dag.values(): 272 deps -= current 273 274 result.extend(sorted(current)) # type: ignore 275 276 return result
Sorts a given directed acyclic graph in topological order.
Arguments:
- dag: The graph to be sorted.
Returns:
A list that contains all of the graph's nodes in topological order.
279def open_file(file_name: str) -> t.TextIO: 280 """Open a file that may be compressed as gzip and return it in universal newline mode.""" 281 with open(file_name, "rb") as f: 282 gzipped = f.read(2) == b"\x1f\x8b" 283 284 if gzipped: 285 import gzip 286 287 return gzip.open(file_name, "rt", newline="") 288 289 return open(file_name, encoding="utf-8", newline="")
Open a file that may be compressed as gzip and return it in universal newline mode.
292@contextmanager 293def csv_reader(read_csv: exp.ReadCSV) -> t.Any: 294 """ 295 Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`. 296 297 Args: 298 read_csv: A `ReadCSV` function call. 299 300 Yields: 301 A python csv reader. 302 """ 303 args = read_csv.expressions 304 file = open_file(read_csv.name) 305 306 delimiter = "," 307 args = iter(arg.name for arg in args) # type: ignore 308 for k, v in zip(args, args): 309 if k == "delimiter": 310 delimiter = v 311 312 try: 313 import csv as csv_ 314 315 yield csv_.reader(file, delimiter=delimiter) 316 finally: 317 file.close()
Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
Arguments:
- read_csv: A `ReadCSV` function call.
Yields:
A python csv reader.
320def find_new_name(taken: t.Collection[str], base: str) -> str: 321 """ 322 Searches for a new name. 323 324 Args: 325 taken: A collection of taken names. 326 base: Base name to alter. 327 328 Returns: 329 The new, available name. 330 """ 331 if base not in taken: 332 return base 333 334 i = 2 335 new = f"{base}_{i}" 336 while new in taken: 337 i += 1 338 new = f"{base}_{i}" 339 340 return new
Searches for a new name.
Arguments:
- taken: A collection of taken names.
- base: Base name to alter.
Returns:
The new, available name.
359def name_sequence(prefix: str) -> t.Callable[[], str]: 360 """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").""" 361 sequence = count() 362 return lambda: f"{prefix}{next(sequence)}"
Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").
365def object_to_dict(obj: t.Any, **kwargs) -> t.Dict: 366 """Returns a dictionary created from an object's attributes.""" 367 return { 368 **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()}, 369 **kwargs, 370 }
Returns a dictionary created from an object's attributes.
373def split_num_words( 374 value: str, sep: str, min_num_words: int, fill_from_start: bool = True 375) -> t.List[t.Optional[str]]: 376 """ 377 Perform a split on a value and return N words as a result with `None` used for words that don't exist. 378 379 Args: 380 value: The value to be split. 381 sep: The value to use to split on. 382 min_num_words: The minimum number of words that are going to be in the result. 383 fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list. 384 385 Examples: 386 >>> split_num_words("db.table", ".", 3) 387 [None, 'db', 'table'] 388 >>> split_num_words("db.table", ".", 3, fill_from_start=False) 389 ['db', 'table', None] 390 >>> split_num_words("db.table", ".", 1) 391 ['db', 'table'] 392 393 Returns: 394 The list of words returned by `split`, possibly augmented by a number of `None` values. 395 """ 396 words = value.split(sep) 397 if fill_from_start: 398 return [None] * (min_num_words - len(words)) + words 399 return words + [None] * (min_num_words - len(words))
Perform a split on a value and return N words as a result with `None` used for words that don't exist.
Arguments:
- value: The value to be split.
- sep: The value to use to split on.
- min_num_words: The minimum number of words that are going to be in the result.
- fill_from_start: Indicates whether `None` values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3) [None, 'db', 'table'] >>> split_num_words("db.table", ".", 3, fill_from_start=False) ['db', 'table', None] >>> split_num_words("db.table", ".", 1) ['db', 'table']
Returns:
The list of words returned by `split`, possibly augmented by a number of `None` values.
402def is_iterable(value: t.Any) -> bool: 403 """ 404 Checks if the value is an iterable, excluding the types `str` and `bytes`. 405 406 Examples: 407 >>> is_iterable([1,2]) 408 True 409 >>> is_iterable("test") 410 False 411 412 Args: 413 value: The value to check if it is an iterable. 414 415 Returns: 416 A `bool` value indicating if it is an iterable. 417 """ 418 from sqlglot import Expression 419 420 return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
Checks if the value is an iterable, excluding the types `str` and `bytes`.
Examples:
>>> is_iterable([1,2]) True >>> is_iterable("test") False
Arguments:
- value: The value to check if it is an iterable.
Returns:
A `bool` value indicating if it is an iterable.
423def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]: 424 """ 425 Flattens an iterable that can contain both iterable and non-iterable elements. Objects of 426 type `str` and `bytes` are not regarded as iterables. 427 428 Examples: 429 >>> list(flatten([[1, 2], 3, {4}, (5, "bla")])) 430 [1, 2, 3, 4, 5, 'bla'] 431 >>> list(flatten([1, 2, 3])) 432 [1, 2, 3] 433 434 Args: 435 values: The value to be flattened. 436 437 Yields: 438 Non-iterable elements in `values`. 439 """ 440 for value in values: 441 if is_iterable(value): 442 yield from flatten(value) 443 else: 444 yield value
Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type `str` and `bytes` are not regarded as iterables.
Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")])) [1, 2, 3, 4, 5, 'bla'] >>> list(flatten([1, 2, 3])) [1, 2, 3]
Arguments:
- values: The value to be flattened.
Yields:
Non-iterable elements in `values`.
447def dict_depth(d: t.Dict) -> int: 448 """ 449 Get the nesting depth of a dictionary. 450 451 Example: 452 >>> dict_depth(None) 453 0 454 >>> dict_depth({}) 455 1 456 >>> dict_depth({"a": "b"}) 457 1 458 >>> dict_depth({"a": {}}) 459 2 460 >>> dict_depth({"a": {"b": {}}}) 461 3 462 """ 463 try: 464 return 1 + dict_depth(next(iter(d.values()))) 465 except AttributeError: 466 # d doesn't have attribute "values" 467 return 0 468 except StopIteration: 469 # d.values() returns an empty sequence 470 return 1
Get the nesting depth of a dictionary.
Example:
>>> dict_depth(None) 0 >>> dict_depth({}) 1 >>> dict_depth({"a": "b"}) 1 >>> dict_depth({"a": {}}) 2 >>> dict_depth({"a": {"b": {}}}) 3
473def first(it: t.Iterable[T]) -> T: 474 """Returns the first element from an iterable (useful for sets).""" 475 return next(i for i in it)
Returns the first element from an iterable (useful for sets).
478def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]: 479 if isinstance(value, bool) or value is None: 480 return value 481 482 # Coerce the value to boolean if it matches to the truthy/falsy values below 483 value_lower = value.lower() 484 if value_lower in ("true", "1"): 485 return True 486 if value_lower in ("false", "0"): 487 return False 488 489 return value
492def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]: 493 """ 494 Merges a sequence of ranges, represented as tuples (low, high) whose values 495 belong to some totally-ordered set. 496 497 Example: 498 >>> merge_ranges([(1, 3), (2, 6)]) 499 [(1, 6)] 500 """ 501 if not ranges: 502 return [] 503 504 ranges = sorted(ranges) 505 506 merged = [ranges[0]] 507 508 for start, end in ranges[1:]: 509 last_start, last_end = merged[-1] 510 511 if start <= last_end: 512 merged[-1] = (last_start, max(last_end, end)) 513 else: 514 merged.append((start, end)) 515 516 return merged
Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.
Example:
>>> merge_ranges([(1, 3), (2, 6)]) [(1, 6)]
547class SingleValuedMapping(t.Mapping[K, V]): 548 """ 549 Mapping where all keys return the same value. 550 551 This rigamarole is meant to avoid copying keys, which was originally intended 552 as an optimization while qualifying columns for tables with lots of columns. 553 """ 554 555 def __init__(self, keys: t.Collection[K], value: V): 556 self._keys = keys if isinstance(keys, Set) else set(keys) 557 self._value = value 558 559 def __getitem__(self, key: K) -> V: 560 if key in self._keys: 561 return self._value 562 raise KeyError(key) 563 564 def __len__(self) -> int: 565 return len(self._keys) 566 567 def __iter__(self) -> t.Iterator[K]: 568 return iter(self._keys)
Mapping where all keys return the same value.
This rigamarole is meant to avoid copying keys, which was originally intended as an optimization while qualifying columns for tables with lots of columns.