sqlglot.helper
from __future__ import annotations

import datetime
import inspect
import logging
import re
import sys
import typing as t
from collections.abc import Collection, Set
from contextlib import contextmanager
from copy import copy
from enum import Enum
from itertools import count

if t.TYPE_CHECKING:
    from sqlglot import exp
    from sqlglot._typing import A, E, T
    from sqlglot.expressions import Expression


CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")


class AutoName(Enum):
    """
    This is used for creating Enum classes where `auto()` is the string form
    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    def _generate_next_value_(name, _start, _count, _last_values):
        return name


class classproperty(property):
    """
    Similar to a normal property but works for class methods
    """

    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
        return classmethod(self.fget).__get__(None, owner)()  # type: ignore


def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
    try:
        return seq[index]
    except IndexError:
        return None


@t.overload
def ensure_list(value: t.Collection[T]) -> t.List[T]: ...


@t.overload
def ensure_list(value: None) -> t.List: ...


@t.overload
def ensure_list(value: T) -> t.List[T]: ...


def ensure_list(value):
    """
    Ensures that a value is a list, otherwise casts or wraps it into one.

    Args:
        value: The value of interest.

    Returns:
        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)

    return [value]


@t.overload
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...


@t.overload
def ensure_collection(value: T) -> t.Collection[T]: ...


def ensure_collection(value):
    """
    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.

    Args:
        value: The value of interest.

    Returns:
        The value if it's a collection, or else the value wrapped in a list.
    """
    if value is None:
        return []
    return (
        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
    )


def csv(*args: str, sep: str = ", ") -> str:
    """
    Formats any number of string arguments as CSV.

    Args:
        args: The string arguments to format.
        sep: The argument separator.

    Returns:
        The arguments formatted as a CSV string.
    """
    return sep.join(arg for arg in args if arg)


def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: The name of the module to search for subclasses in.
        classes: Class(es) we want to find the subclasses of.
        exclude: Class(es) we want to exclude from the returned list.

    Returns:
        The target subclasses.
    """
    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
        )
    ]


def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
) -> t.List[E]:
    """
    Applies an offset to a given integer literal expression.

    Args:
        this: The target of the index.
        expressions: The expression the offset will be applied to, wrapped in a list.
        offset: The offset that will be applied.

    Returns:
        The original expression with the offset applied to it, wrapped in a list. If the provided
        `expressions` argument contains more than one expression, it's returned unaffected.
    """
    if not offset or len(expressions) != 1:
        return expressions

    expression = expressions[0]

    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    if not this.type:
        annotate_types(this)

    if t.cast(exp.DataType, this.type).this not in (
        exp.DataType.Type.UNKNOWN,
        exp.DataType.Type.ARRAY,
    ):
        return expressions

    if not expression.type:
        annotate_types(expression)

    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
        logger.info("Applying array index offset (%s)", offset)
        expression = simplify(expression + offset)
        return [expression]

    return expressions


def camel_to_snake_case(name: str) -> str:
    """Converts `name` from camelCase to snake_case and returns the result."""
    return CAMEL_CASE_PATTERN.sub("_", name).upper()


def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Applies a transformation to a given expression until a fix point is reached.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied.

    Returns:
        The transformed expression.
    """
    while True:
        for n in reversed(tuple(expression.walk())):
            n._hash = hash(n)

        start = hash(expression)
        expression = func(expression)

        for n in expression.walk():
            n._hash = None
        if start == hash(expression):
            break

    return expression


def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Args:
        dag: The graph to be sorted.

    Returns:
        A list that contains all of the graph's nodes in topological order.
    """
    result = []

    for node, deps in tuple(dag.items()):
        for dep in deps:
            if dep not in dag:
                dag[dep] = set()

    while dag:
        current = {node for node, deps in dag.items() if not deps}

        if not current:
            raise ValueError("Cycle error")

        for node in current:
            dag.pop(node)

        for deps in dag.values():
            deps -= current

        result.extend(sorted(current))  # type: ignore

    return result


def open_file(file_name: str) -> t.TextIO:
    """Open a file that may be compressed as gzip and return it in universal newline mode."""
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        return gzip.open(file_name, "rt", newline="")

    return open(file_name, encoding="utf-8", newline="")


@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader.
    """
    args = read_csv.expressions
    file = open_file(read_csv.name)

    delimiter = ","
    args = iter(arg.name for arg in args)  # type: ignore
    for k, v in zip(args, args):
        if k == "delimiter":
            delimiter = v

    try:
        import csv as csv_

        yield csv_.reader(file, delimiter=delimiter)
    finally:
        file.close()


def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Searches for a new name.

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        The new, available name.
    """
    if base not in taken:
        return base

    i = 2
    new = f"{base}_{i}"
    while new in taken:
        i += 1
        new = f"{base}_{i}"

    return new


def is_int(text: str) -> bool:
    return is_type(text, int)


def is_float(text: str) -> bool:
    return is_type(text, float)


def is_type(text: str, target_type: t.Type) -> bool:
    try:
        target_type(text)
        return True
    except ValueError:
        return False


def name_sequence(prefix: str) -> t.Callable[[], str]:
    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
    sequence = count()
    return lambda: f"{prefix}{next(sequence)}"


def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """Returns a dictionary created from an object's attributes."""
    return {
        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
        **kwargs,
    }


def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Perform a split on a value and return N words as a result with `None` used for words that don't exist.

    Args:
        value: The value to be split.
        sep: The value to use to split on.
        min_num_words: The minimum number of words that are going to be in the result.
        fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly augmented by a number of `None` values.
    """
    words = value.split(sep)
    if fill_from_start:
        return [None] * (min_num_words - len(words)) + words
    return words + [None] * (min_num_words - len(words))


def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is an iterable, excluding the types `str` and `bytes`.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    from sqlglot import Expression

    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))


def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
    type `str` and `bytes` are not regarded as iterables.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`.
    """
    for value in values:
        if is_iterable(value):
            yield from flatten(value)
        else:
            yield value


def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        return 1 + dict_depth(next(iter(d.values())))
    except AttributeError:
        # d doesn't have attribute "values"
        return 0
    except StopIteration:
        # d.values() returns an empty sequence
        return 1


def first(it: t.Iterable[T]) -> T:
    """Returns the first element from an iterable (useful for sets)."""
    return next(i for i in it)


def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
    if isinstance(value, bool) or value is None:
        return value

    # Coerce the value to boolean if it matches to the truthy/falsy values below
    value_lower = value.lower()
    if value_lower in ("true", "1"):
        return True
    if value_lower in ("false", "0"):
        return False

    return value


def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Merges a sequence of ranges, represented as tuples (low, high) whose values
    belong to some totally-ordered set.
    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    ranges = sorted(ranges)

    merged = [ranges[0]]

    for start, end in ranges[1:]:
        last_start, last_end = merged[-1]

        if start <= last_end:
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))

    return merged


def is_iso_date(text: str) -> bool:
    try:
        datetime.date.fromisoformat(text)
        return True
    except ValueError:
        return False


def is_iso_datetime(text: str) -> bool:
    try:
        datetime.datetime.fromisoformat(text)
        return True
    except ValueError:
        return False


# Interval units that operate on date components
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}


def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
    return expression is not None and expression.name.lower() in DATE_UNITS


K = t.TypeVar("K")
V = t.TypeVar("V")


class SingleValuedMapping(t.Mapping[K, V]):
    """
    Mapping where all keys return the same value.

    This rigamarole is meant to avoid copying keys, which was originally intended
    as an optimization while qualifying columns for tables with lots of columns.
    """

    def __init__(self, keys: t.Collection[K], value: V):
        self._keys = keys if isinstance(keys, Set) else set(keys)
        self._value = value

    def __getitem__(self, key: K) -> V:
        if key in self._keys:
            return self._value
        raise KeyError(key)

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> t.Iterator[K]:
        return iter(self._keys)
class AutoName(Enum):
    """
    This is used for creating Enum classes where `auto()` is the string form
    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    def _generate_next_value_(name, _start, _count, _last_values):
        return name
This is used for creating Enum classes where `auto()` is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
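For illustration, a minimal sketch of how an enum might build on AutoName (the TokenKind class below is hypothetical, not part of sqlglot):

from enum import auto

from sqlglot.helper import AutoName


class TokenKind(AutoName):  # hypothetical enum, purely for illustration
    SELECT = auto()
    FROM = auto()


print(TokenKind.SELECT.value)  # "SELECT" -- auto() resolves to the member name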
class classproperty(property):
    """
    Similar to a normal property but works for class methods
    """

    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
Similar to a normal property but works for class methods
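A small usage sketch; the Dialect class and its _name attribute are hypothetical and only demonstrate class-level access:

from sqlglot.helper import classproperty


class Dialect:  # hypothetical class for illustration
    _name = "generic"

    @classproperty
    def name(cls) -> str:
        # The getter receives the class itself, not an instance
        return cls._name


print(Dialect.name)  # "generic" -- no instance required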
def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
    try:
        return seq[index]
    except IndexError:
        return None
Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds.
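Based on the implementation above, an out-of-range lookup returns None instead of raising:

from sqlglot.helper import seq_get

args = ["a", "b"]
seq_get(args, 1)  # 'b'
seq_get(args, 5)  # None instead of an IndexError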
def ensure_list(value):
    """
    Ensures that a value is a list, otherwise casts or wraps it into one.

    Args:
        value: The value of interest.

    Returns:
        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)

    return [value]
Ensures that a value is a list, otherwise casts or wraps it into one.
Arguments:
- value: The value of interest.
Returns:
The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
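A quick sketch of the three cases handled above (list/tuple, None, and a scalar):

from sqlglot.helper import ensure_list

ensure_list(("a", "b"))  # ['a', 'b'] -- tuples are cast to lists
ensure_list(None)        # []
ensure_list("a")         # ['a'] -- scalars are wrapped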
def ensure_collection(value):
    """
    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.

    Args:
        value: The value of interest.

    Returns:
        The value if it's a collection, or else the value wrapped in a list.
    """
    if value is None:
        return []
    return (
        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
    )
Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
Arguments:
- value: The value of interest.
Returns:
The value if it's a collection, or else the value wrapped in a list.
def csv(*args: str, sep: str = ", ") -> str:
    """
    Formats any number of string arguments as CSV.

    Args:
        args: The string arguments to format.
        sep: The argument separator.

    Returns:
        The arguments formatted as a CSV string.
    """
    return sep.join(arg for arg in args if arg)
Formats any number of string arguments as CSV.
Arguments:
- args: The string arguments to format.
- sep: The argument separator.
Returns:
The arguments formatted as a CSV string.
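Per the implementation, empty arguments are dropped before joining; a small sketch:

from sqlglot.helper import csv

csv("a", "b", "", "c")    # 'a, b, c' -- empty strings are skipped
csv("x", "y", sep=" | ")  # 'x | y'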
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: The name of the module to search for subclasses in.
        classes: Class(es) we want to find the subclasses of.
        exclude: Class(es) we want to exclude from the returned list.

    Returns:
        The target subclasses.
    """
    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
        )
    ]
Returns all subclasses for a collection of classes, possibly excluding some of them.
Arguments:
- module_name: The name of the module to search for subclasses in.
- classes: Class(es) we want to find the subclasses of.
- exclude: Class(es) we want to exclude from the returned list.
Returns:
The target subclasses.
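A hedged sketch of enumerating sqlglot expression classes; it assumes sqlglot.expressions is already loaded (importing sqlglot does this) and uses exp.Func as the base class of interest:

from sqlglot import exp
from sqlglot.helper import subclasses

# All function expression classes defined in sqlglot.expressions,
# excluding the exp.Func base class itself.
func_classes = subclasses("sqlglot.expressions", exp.Func, exclude=(exp.Func,))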
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
) -> t.List[E]:
    """
    Applies an offset to a given integer literal expression.

    Args:
        this: The target of the index.
        expressions: The expression the offset will be applied to, wrapped in a list.
        offset: The offset that will be applied.

    Returns:
        The original expression with the offset applied to it, wrapped in a list. If the provided
        `expressions` argument contains more than one expression, it's returned unaffected.
    """
    if not offset or len(expressions) != 1:
        return expressions

    expression = expressions[0]

    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    if not this.type:
        annotate_types(this)

    if t.cast(exp.DataType, this.type).this not in (
        exp.DataType.Type.UNKNOWN,
        exp.DataType.Type.ARRAY,
    ):
        return expressions

    if not expression.type:
        annotate_types(expression)

    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
        logger.info("Applying array index offset (%s)", offset)
        expression = simplify(expression + offset)
        return [expression]

    return expressions
Applies an offset to a given integer literal expression.
Arguments:
- this: The target of the index.
- expressions: The expression the offset will be applied to, wrapped in a list.
- offset: The offset that will be applied.
Returns:
The original expression with the offset applied to it, wrapped in a list. If the provided `expressions` argument contains more than one expression, it's returned unaffected.
def camel_to_snake_case(name: str) -> str:
    """Converts `name` from camelCase to snake_case and returns the result."""
    return CAMEL_CASE_PATTERN.sub("_", name).upper()
Converts `name` from camelCase to snake_case and returns the result.
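Note that the implementation above also upper-cases the result, so the output is an upper-cased snake form:

from sqlglot.helper import camel_to_snake_case

camel_to_snake_case("indexOffset")  # 'INDEX_OFFSET'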
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Applies a transformation to a given expression until a fix point is reached.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied.

    Returns:
        The transformed expression.
    """
    while True:
        for n in reversed(tuple(expression.walk())):
            n._hash = hash(n)

        start = hash(expression)
        expression = func(expression)

        for n in expression.walk():
            n._hash = None
        if start == hash(expression):
            break

    return expression
Applies a transformation to a given expression until a fixed point is reached.
Arguments:
- expression: The expression to be transformed.
- func: The transformation to be applied.
Returns:
The transformed expression.
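A minimal sketch, assuming sqlglot keeps explicit exp.Paren nodes when parsing nested parentheses; the unwrap_parens rewrite is hypothetical and exists only to show the fixed-point loop:

import sqlglot
from sqlglot import exp
from sqlglot.helper import while_changing


def unwrap_parens(node: exp.Expression) -> exp.Expression:
    # Peels one layer of parentheses per call; while_changing keeps
    # calling it until the expression's hash stops changing.
    return node.this if isinstance(node, exp.Paren) else node


expression = sqlglot.parse_one("((1))")
while_changing(expression, unwrap_parens)  # returns the inner literal 1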
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Args:
        dag: The graph to be sorted.

    Returns:
        A list that contains all of the graph's nodes in topological order.
    """
    result = []

    for node, deps in tuple(dag.items()):
        for dep in deps:
            if dep not in dag:
                dag[dep] = set()

    while dag:
        current = {node for node, deps in dag.items() if not deps}

        if not current:
            raise ValueError("Cycle error")

        for node in current:
            dag.pop(node)

        for deps in dag.values():
            deps -= current

        result.extend(sorted(current))  # type: ignore

    return result
Sorts a given directed acyclic graph in topological order.
Arguments:
- dag: The graph to be sorted.
Returns:
A list that contains all of the graph's nodes in topological order.
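A small sketch; note that, per the implementation above, the input dict is mutated, so pass a throwaway copy if you still need it afterwards:

from sqlglot.helper import tsort

# Each key maps to the set of nodes it depends on.
dependencies = {"b": {"a"}, "c": {"b"}, "a": set()}
tsort(dependencies)  # ['a', 'b', 'c']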
def open_file(file_name: str) -> t.TextIO:
    """Open a file that may be compressed as gzip and return it in universal newline mode."""
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        return gzip.open(file_name, "rt", newline="")

    return open(file_name, encoding="utf-8", newline="")
Open a file that may be compressed as gzip and return it in universal newline mode.
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader.
    """
    args = read_csv.expressions
    file = open_file(read_csv.name)

    delimiter = ","
    args = iter(arg.name for arg in args)  # type: ignore
    for k, v in zip(args, args):
        if k == "delimiter":
            delimiter = v

    try:
        import csv as csv_

        yield csv_.reader(file, delimiter=delimiter)
    finally:
        file.close()
Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
Arguments:
- read_csv: A `ReadCSV` function call.
Yields:
A Python csv reader.
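A hedged sketch; it assumes the parser turns a READ_CSV call into an exp.ReadCSV node (as used by sqlglot's Python executor), and 'data.csv' is a hypothetical pipe-delimited file:

import sqlglot
from sqlglot import exp
from sqlglot.helper import csv_reader

query = sqlglot.parse_one("SELECT * FROM READ_CSV('data.csv', 'delimiter', '|')")
read_csv = query.find(exp.ReadCSV)

with csv_reader(read_csv) as reader:
    for row in reader:  # each row is a list of strings
        print(row)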
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Searches for a new name.

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        The new, available name.
    """
    if base not in taken:
        return base

    i = 2
    new = f"{base}_{i}"
    while new in taken:
        i += 1
        new = f"{base}_{i}"

    return new
Searches for a new name.
Arguments:
- taken: A collection of taken names.
- base: Base name to alter.
Returns:
The new, available name.
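Following the implementation, the base name is returned untouched when it's free; otherwise a numeric suffix starting at 2 is appended:

from sqlglot.helper import find_new_name

find_new_name({"a", "b"}, "c")    # 'c' -- base name is free
find_new_name({"a", "a_2"}, "a")  # 'a_3' -- first available suffix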
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
    sequence = count()
    return lambda: f"{prefix}{next(sequence)}"
Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").
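A quick sketch: each generator keeps its own counter, so separate sequences do not interfere:

from sqlglot.helper import name_sequence

next_name = name_sequence("_q_")
next_name()  # '_q_0'
next_name()  # '_q_1'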
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """Returns a dictionary created from an object's attributes."""
    return {
        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
        **kwargs,
    }
Returns a dictionary created from an object's attributes.
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Perform a split on a value and return N words as a result with `None` used for words that don't exist.

    Args:
        value: The value to be split.
        sep: The value to use to split on.
        min_num_words: The minimum number of words that are going to be in the result.
        fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly augmented by a number of `None` values.
    """
    words = value.split(sep)
    if fill_from_start:
        return [None] * (min_num_words - len(words)) + words
    return words + [None] * (min_num_words - len(words))
Perform a split on a value and return N words as a result with `None` used for words that don't exist.
Arguments:
- value: The value to be split.
- sep: The value to use to split on.
- min_num_words: The minimum number of words that are going to be in the result.
- fill_from_start: Indicates whether `None` values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:
The list of words returned by `split`, possibly augmented by a number of `None` values.
def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is an iterable, excluding the types `str` and `bytes`.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    from sqlglot import Expression

    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
Checks if the value is an iterable, excluding the types `str` and `bytes`.
Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
- value: The value to check if it is an iterable.
Returns:
A `bool` value indicating if it is an iterable.
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
    type `str` and `bytes` are not regarded as iterables.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`.
    """
    for value in values:
        if is_iterable(value):
            yield from flatten(value)
        else:
            yield value
Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type `str` and `bytes` are not regarded as iterables.
Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
- values: The value to be flattened.
Yields:
Non-iterable elements in `values`.
def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        return 1 + dict_depth(next(iter(d.values())))
    except AttributeError:
        # d doesn't have attribute "values"
        return 0
    except StopIteration:
        # d.values() returns an empty sequence
        return 1
Get the nesting depth of a dictionary.
Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: t.Iterable[T]) -> T:
    """Returns the first element from an iterable (useful for sets)."""
    return next(i for i in it)
Returns the first element from an iterable (useful for sets).
def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]:
    if isinstance(value, bool) or value is None:
        return value

    # Coerce the value to boolean if it matches to the truthy/falsy values below
    value_lower = value.lower()
    if value_lower in ("true", "1"):
        return True
    if value_lower in ("false", "0"):
        return False

    return value
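Per the implementation, only the strings "true"/"1" and "false"/"0" (case-insensitive) are coerced; anything else is passed through unchanged:

from sqlglot.helper import to_bool

to_bool("true")   # True
to_bool("0")      # False
to_bool("maybe")  # 'maybe' -- returned unchanged
to_bool(None)     # None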
def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Merges a sequence of ranges, represented as tuples (low, high) whose values
    belong to some totally-ordered set.

    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    ranges = sorted(ranges)

    merged = [ranges[0]]

    for start, end in ranges[1:]:
        last_start, last_end = merged[-1]

        if start <= last_end:
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))

    return merged
Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.
Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]
class SingleValuedMapping(t.Mapping[K, V]):
    """
    Mapping where all keys return the same value.

    This rigamarole is meant to avoid copying keys, which was originally intended
    as an optimization while qualifying columns for tables with lots of columns.
    """

    def __init__(self, keys: t.Collection[K], value: V):
        self._keys = keys if isinstance(keys, Set) else set(keys)
        self._value = value

    def __getitem__(self, key: K) -> V:
        if key in self._keys:
            return self._value
        raise KeyError(key)

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> t.Iterator[K]:
        return iter(self._keys)
Mapping where all keys return the same value.
This rigamarole is meant to avoid copying keys, which was originally intended as an optimization while qualifying columns for tables with lots of columns.
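A small sketch of the behavior; the column names and table alias below are hypothetical:

from sqlglot.helper import SingleValuedMapping

columns = SingleValuedMapping(["id", "name", "created_at"], "t1")
columns["name"]       # 't1' -- every key maps to the same value
len(columns)          # 3
"missing" in columns  # False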