sqlglot.schema
from __future__ import annotations

import abc
import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth, first
from sqlglot.trie import TrieResult, in_trie, new_trie

from sqlglot.helper import trait


if t.TYPE_CHECKING:
    from sqlglot._typing import SchemaArgs
    from sqlglot.dialects.dialect import DialectType
    from collections.abc import Sequence
    from typing_extensions import Unpack

    ColumnMapping = t.Union[dict[str, t.Any], str, list[str]]


@trait
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @property
    def dialect(self) -> Dialect | None:
        """
        Returns None by default. Subclasses that require dialect-specific
        behavior should override this property.
        """
        return None

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: ColumnMapping | None = None,
        dialect: DialectType = None,
        normalize: bool | None = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns (invisible ones are excluded).
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        This base implementation knows no UDFs and always answers UNKNOWN.

        Args:
            udf: the UDF expression or string.
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        return exp.DType.UNKNOWN.into_expr()

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> tuple[str, ...]:
        """
        Table arguments this schema support, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        return True


class AbstractMappingSchema:
    """
    Shared machinery for schemas backed by nested dictionaries.

    Keeps a table mapping and a UDF mapping, each paired with a trie built from
    the reversed key paths of the mapping.
    """

    def __init__(
        self,
        mapping: dict[str, object] | None = None,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        self.mapping: dict[str, object] = mapping or {}
        # Paths are stored reversed, matching the order produced by `table_parts`.
        self.mapping_trie: dict[str, object] = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )
        self.udf_mapping: dict[str, object] = udf_mapping or {}
        self.udf_trie: dict[str, object] = new_trie(
            tuple(reversed(path))
            for path in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Computed lazily by the `supported_table_args` property.
        self._supported_table_args: tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether the table mapping is empty."""
        return not self.mapping

    def depth(self) -> int:
        """Returns the nesting depth of the table mapping, as computed by `dict_depth`."""
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        """Returns the nesting depth of the UDF mapping, as computed by `dict_depth`."""
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> tuple[str, ...]:
        """
        The leading `exp.TABLE_PARTS` entries that this mapping's depth supports.

        Raises:
            SchemaError: if the mapping is non-empty and its depth is greater than 3.
        """
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # falsy depth: no table arguments are supported
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> list[str]:
        """Returns the names of `table`'s parts in reversed order."""
        return [p.name for p in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> list[str]:
        """Returns the names making up `udf` in reversed order, truncated to the UDF mapping depth."""
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
        return list(reversed(parts))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: list[str],
        trie: dict[str, object],
        raise_on_missing: bool,
    ) -> list[str] | None:
        """
        Resolve `parts` against `trie`, completing an unambiguous prefix.

        Note that `parts` is extended in place when a unique completion exists.

        Args:
            parts: the (reversed) path to look up.
            trie: the trie to search.
            raise_on_missing: whether to raise when the prefix has several completions.

        Returns:
            The resolved parts, or None if the lookup failed or was ambiguous.

        Raises:
            SchemaError: if the prefix is ambiguous and `raise_on_missing` is set.
        """
        value, trie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

                return None

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
                This base implementation does not use the flag.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Any | None:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        parts = self.udf_parts(udf)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self,
        parts: Sequence[str],
        d: dict[str, object] | None = None,
        raise_on_missing: bool = True,
    ) -> t.Any | None:
        """
        Look up reversed `parts` in `d`, labelling each level with a supported table arg.

        Args:
            parts: the path in reversed order (as returned by `table_parts`).
            d: the dictionary to search; a falsy value falls back to `self.mapping`.
            raise_on_missing: whether to raise if a key along the path is missing.

        Returns:
            The value found, or None if it doesn't exist.
        """
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
        udf_mapping: Optional nested mapping from UDF name parts to their return types.
    """

    def __init__(
        self,
        schema: dict[str, object] | None = None,
        visible: dict[str, object] | None = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        self.visible: dict[str, object] = {} if visible is None else visible
        self.normalize: bool = normalize
        self._dialect: Dialect = Dialect.get_or_raise(dialect)
        self._type_mapping_cache: dict[str, exp.DataType] = {}
        self._normalized_table_cache: dict[tuple[exp.Table, DialectType, bool], exp.Table] = {}
        # Key: (name, quoted, dialect, is_table, normalize)
        self._normalized_name_cache: dict[tuple[str, bool, DialectType, bool, bool], str] = {}
        self._find_cache: dict[tuple[exp.Table, bool], dict[str, object] | None] = {}
        # Must be set before super().__init__, which calls self.depth().
        self._depth: int = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )

    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from the attributes of an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """
        Cached variant of `AbstractMappingSchema.find`.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table, or None if it was not found.
        """
        cache_key = (table, ensure_data_types)
        schema = self._find_cache.get(cache_key)

        # None doubles as the "miss" marker, so unsuccessful lookups are always
        # recomputed. That keeps `raise_on_missing` effective: the cache key does
        # not include it, and a cached miss must not suppress a later raise.
        if schema is None:
            schema = super().find(table, raise_on_missing=raise_on_missing)
            if ensure_data_types and isinstance(schema, dict):
                schema = {
                    col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                    for col, dtype in schema.items()
                }
            self._find_cache[cache_key] = schema

        return schema

    def copy(
        self, schema: dict[str, object] | None = None, **kwargs: Unpack[SchemaArgs]
    ) -> MappingSchema:
        """
        Create a new `MappingSchema` from this one.

        Args:
            schema: the mapping to use; defaults to a shallow copy of the current mapping.
            **kwargs: overrides for the `visible`, `dialect`, `normalize` and `udf_mapping` settings.

        Returns:
            The new schema.
        """
        mapping_kwargs: SchemaArgs = {
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
            "udf_mapping": self.udf_mapping.copy(),
            **kwargs,
        }
        return MappingSchema(self.mapping.copy() if schema is None else schema, **mapping_kwargs)

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: ColumnMapping | None = None,
        dialect: DialectType = None,
        normalize: bool | None = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.

        Raises:
            SchemaError: if `match_depth` is set and the table's depth differs from a non-empty schema's.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # The table is already known and there is nothing new to record.
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)
        # Drop any cached lookups for this table, for both `ensure_data_types` values.
        self._find_cache.pop((normalized_table, True), None)
        self._find_cache.pop((normalized_table, False), None)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> list[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns. Has no effect
                when no `visible` mapping was provided.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names (empty if the table is not found).
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema: dict[str, object] | None = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type, or UNKNOWN if the table or column is not found.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DType.UNKNOWN.into_expr()

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DType.UNKNOWN.into_expr()

        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DType.UNKNOWN.into_expr()

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: dict[str, object]) -> dict[str, object]:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if a table is not nested at the expected level or has no columns.
        """
        normalized_mapping: dict[str, object] = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_udfs(self, udfs: dict[str, object]) -> dict[str, object]:
        """
        Normalizes all identifiers in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: dict[str, object] = {}

        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> list[str]:
        """
        Extract and normalize UDF parts for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).

        Raises:
            SchemaError: if a string `udf` does not parse to a (possibly qualified) anonymous function.
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expr = exp.maybe_parse(udf, dialect=dialect)

            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")
        parts = self.udf_parts(udf)

        if normalize:
            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

        return parts

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.Table:
        """
        Parse `table` if needed and normalize the identifiers of its parts.

        Args:
            table: the `Table` expression instance or string representing the table.
            dialect: the SQL dialect used for parsing and normalization; defaults to the schema's.
            normalize: whether to normalize identifiers; defaults to the schema's setting.

        Returns:
            The (possibly cached) normalized table expression.
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # The cache is keyed by the Table expression itself (as a dict key, i.e. via its
        # `__hash__`/`__eq__`), plus dialect and the normalize flag. Entries are stored
        # under the *normalized* table below, so a hit requires the incoming table to
        # match an already-normalized one.
        if isinstance(table, exp.Table) and (
            cached := self._normalized_table_cache.get((table, dialect, normalize))
        ):
            return cached

        # Copy only when we are about to rewrite identifiers in place.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        self._normalized_table_cache[(normalized_table, dialect, normalize)] = normalized_table
        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: bool | None = None,
    ) -> str:
        """
        Normalize a single identifier name, memoizing the result.

        Args:
            name: the identifier, as a string or `Identifier` expression.
            dialect: the SQL dialect used for normalization; defaults to the schema's.
            is_table: whether the identifier names a table.
            normalize: whether to normalize; defaults to the schema's setting.

        Returns:
            The normalized name.
        """
        normalize = self.normalize if normalize is None else normalize

        dialect = dialect or self.dialect

        if isinstance(name, str):
            name_str, quoted = name, False
        else:
            # A quoted Identifier and a plain string can share the same bare name, so the
            # quoted flag is part of the key to keep the two from sharing a cache entry.
            name_str, quoted = name.name, bool(name.args.get("quoted"))
        cache_key = (name_str, quoted, dialect, is_table, normalize)

        cached = self._normalized_name_cache.get(cache_key)
        if cached is not None:
            return cached

        result = normalize_name(
            name,
            dialect=dialect,
            is_table=is_table,
            normalize=normalize,
        ).name

        self._normalized_name_cache[cache_key] = result
        return result

    def depth(self) -> int:
        """Returns the nesting depth of the table mapping, excluding the column level."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if building the type fails with an `AttributeError`.
        """
        # NOTE(review): the cache is keyed by the type string only, so the dialect of the
        # first conversion wins for later calls with the same string — confirm intended.
        if schema_type not in self._type_mapping_cache:
            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.from_str(schema_type, dialect=dialect, udt=udt)
                expression.transform(dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError as e:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") from e

        return self._type_mapping_cache[schema_type]


def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: bool | None = True,
) -> exp.Identifier:
    """
    Turn `identifier` into an `Identifier` expression and optionally normalize it.

    Args:
        identifier: the identifier, as a string or `Identifier` expression.
        dialect: the SQL dialect used to parse a string and to normalize.
        is_table: whether the identifier names a table; recorded in the identifier's `meta`.
        normalize: whether to apply the dialect's identifier normalization.

    Returns:
        The resulting identifier.
    """
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(
    schema: Schema | dict[str, object] | None, **kwargs: Unpack[SchemaArgs]
) -> Schema:
    """Return `schema` unchanged if it is a `Schema`, otherwise wrap it in a `MappingSchema`."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: ColumnMapping | None) -> dict[str, t.Any]:
    """
    Coerce the supported column mapping forms into a dictionary.

    Args:
        mapping: None, a `{name: type}` dict, a `"name: type, name: type"` string,
            or a list of column names.

    Returns:
        A dictionary from column name to type (None for list input).

    Raises:
        IndexError: if an entry of a string mapping has no `:` separator.
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_types: dict[str, t.Any] = {}
        for name_type_str in mapping.split(","):
            # Split on the first colon only so a type that contains ":" is kept whole.
            name_and_type = name_type_str.split(":", 1)
            column_types[name_and_type[0].strip()] = name_and_type[1].strip()
        return column_types
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: dict[str, object], depth: int | None = None, keys: list[str] | None = None
) -> list[list[str]]:
    """
    List the key paths of a nested mapping down to `depth` levels.

    Args:
        schema: the nested mapping to flatten.
        depth: how many levels to descend; defaults to `dict_depth(schema) - 1`.
        keys: the path prefix accumulated so far (used by the recursion).

    Returns:
        One key path per entry reached at `depth`, or earlier if a value is not a dict.
    """
    tables: list[list[str]] = []
    keys = keys or []
    depth = dict_depth(schema) - 1 if depth is None else depth

    for k, v in schema.items():
        if depth == 1 or not isinstance(v, dict):
            tables.append(keys + [k])
        elif depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))

    return tables


def nested_get(
    d: dict[str, object], *path: tuple[str, str], raise_on_missing: bool = True
) -> t.Any | None:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether to raise if a key along the path is missing.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is set.
    """
    result: t.Any = d
    for name, key in path:
        result = result.get(key)
        if result is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return result


def nested_set(d: dict[str, t.Any], keys: Sequence[str], value: t.Any) -> dict[str, t.Any]:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        # setdefault returns the existing child or inserts a new empty dict.
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d
@trait
class Schema(abc.ABC):
    """Abstract interface that every database schema implementation must provide."""

    @property
    def dialect(self) -> Dialect | None:
        """
        The dialect associated with this schema.

        The base implementation returns None; subclasses that need
        dialect-specific behavior override this property.
        """
        return None

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: ColumnMapping | None = None,
        dialect: DialectType = None,
        normalize: bool | None = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register a table, or update one that is already registered.

        Implementations may require column information to be supplied. The table's
        path needs as many qualifiers as the schema's nesting level.

        Args:
            table: a `Table` expression, or a string naming the table.
            column_mapping: description of the table's columns.
            dialect: SQL dialect used to parse `table` when it is given as a string.
            normalize: whether identifiers are normalized according to the dialect.
            match_depth: whether the table is required to match the schema's depth.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> Sequence[str]:
        """
        List the names of a table's columns.

        Args:
            table: a `Table` expression, or a string naming the table.
            only_visible: whether to restrict the result to visible columns.
            dialect: SQL dialect used to parse `table` when it is given as a string.
            normalize: whether identifiers are normalized according to the dialect.

        Returns:
            The column names, as a sequence.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Look up the `sqlglot.exp.DataType` of one column.

        Args:
            table: the table that owns the column.
            column: the column whose type is wanted.
            dialect: SQL dialect used to parse `table` when it is given as a string.
            normalize: whether identifiers are normalized according to the dialect.

        Returns:
            The column's type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> bool:
        """
        Tell whether `column` is one of the columns of `table`.

        Args:
            table: the table that may own the column.
            column: the column to look for, as a name or a `Column` expression.
            dialect: SQL dialect used to parse `table` when it is given as a string.
            normalize: whether identifiers are normalized according to the dialect.

        Returns:
            True when the column's name is among `column_names(table)`, else False.
        """
        if isinstance(column, str):
            wanted = column
        else:
            wanted = column.name

        known_columns = self.column_names(table, dialect=dialect, normalize=normalize)
        return wanted in known_columns

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Look up the return type of a UDF.

        The base implementation has no UDF information and always reports UNKNOWN.

        Args:
            udf: the UDF, as an expression or a string.
            dialect: SQL dialect used to parse string arguments.
            normalize: whether identifiers are normalized.

        Returns:
            The UDF's return type as a DataType, or UNKNOWN when it is not found.
        """
        unknown_type = exp.DType.UNKNOWN.into_expr()
        return unknown_type

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> tuple[str, ...]:
        """
        The table arguments this schema supports, e.g. `("this", "db", "catalog")`.
        """

    @property
    def empty(self) -> bool:
        """Whether the schema holds no tables; the base implementation always says True."""
        return True
Abstract base class for database schemas
@property
def dialect(self) -> Dialect | None:
    """
    Returns None by default. Subclasses that require dialect-specific
    behavior should override this property.

    Returns:
        The schema's dialect; this base implementation always returns None.
    """
    return None
Returns None by default. Subclasses that require dialect-specific behavior should override this property.
@abc.abstractmethod
def add_table(
    self,
    table: exp.Table | str,
    column_mapping: ColumnMapping | None = None,
    dialect: DialectType = None,
    normalize: bool | None = None,
    match_depth: bool = True,
) -> None:
    """
    Register or update a table. Some implementing classes may require column information to also be provided.
    The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

    This is an abstract method: it has no body here and must be implemented by subclasses.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
        match_depth: whether to enforce that the table must match the schema's depth or not.
    """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names(
    self,
    table: exp.Table | str,
    only_visible: bool = False,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> Sequence[str]:
    """
    Get the column names for a table.

    This is an abstract method: it has no body here and must be implemented by subclasses.

    Args:
        table: the `Table` expression instance or string representing the table.
        only_visible: whether to return only the visible columns.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.

    Returns:
        The sequence of column names.
    """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to return only the visible columns (invisible columns are excluded when this is set).
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
@abc.abstractmethod
def get_column_type(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> exp.DataType:
    """
    Look up the `sqlglot.exp.DataType` of a column that lives in this schema.

    Args:
        table: the table that owns the column.
        column: the column whose type is requested.
        dialect: the SQL dialect used to parse `table` when it is given as a string.
        normalize: whether identifiers are normalized according to the dialect of interest.

    Returns:
        The type of the column.
    """
Get the sqlglot.exp.DataType type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
def has_column(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> bool:
    """
    Tell whether `column` is one of the columns of `table`.

    Args:
        table: the table to inspect.
        column: the column to look for, either as a string or as a `Column` expression.
        dialect: the SQL dialect used to parse `table` when it is given as a string.
        normalize: whether identifiers are normalized according to the dialect of interest.

    Returns:
        True when the column's name is among the table's column names, False otherwise.
    """
    # Column expressions are compared by their name, strings are used as given.
    column_name = column.name if not isinstance(column, str) else column
    known_columns = self.column_names(table, dialect=dialect, normalize=normalize)
    return column_name in known_columns
Returns whether column appears in table's schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
def get_udf_type(
    self,
    udf: exp.Anonymous | str,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> exp.DataType:
    """
    Return the type produced by a UDF.

    This base implementation does not inspect its arguments and always answers UNKNOWN.

    Args:
        udf: the UDF expression or string.
        dialect: the SQL dialect for parsing string arguments.
        normalize: whether to normalize identifiers.

    Returns:
        The return type as a DataType, or UNKNOWN if not found.
    """
    unknown_type = exp.DType.UNKNOWN.into_expr()
    return unknown_type
Get the return type of a UDF.
Arguments:
- udf: the UDF expression or string.
- dialect: the SQL dialect for parsing string arguments.
- normalize: whether to normalize identifiers.
Returns:
The return type as a DataType, or UNKNOWN if not found.
class AbstractMappingSchema:
    """Schema storage built on nested dicts, with tries over the reversed key paths for lookups."""

    def __init__(
        self,
        mapping: dict[str, object] | None = None,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        self.mapping: dict[str, object] = mapping or {}
        # Paths are stored reversed in the trie, so lookups start from the innermost name.
        table_paths = flatten_schema(self.mapping, depth=self.depth())
        self.mapping_trie: dict[str, object] = new_trie(
            tuple(reversed(path)) for path in table_paths
        )

        self.udf_mapping: dict[str, object] = udf_mapping or {}
        udf_paths = flatten_schema(self.udf_mapping, depth=self.udf_depth())
        self.udf_trie: dict[str, object] = new_trie(tuple(reversed(path)) for path in udf_paths)

        self._supported_table_args: tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the table mapping holds no entries."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the table mapping, as reported by `dict_depth`."""
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        """Nesting depth of the UDF mapping, as reported by `dict_depth`."""
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> tuple[str, ...]:
        """
        The leading `exp.TABLE_PARTS` entries that the mapping's depth can address.

        The value is computed on first use (once the mapping is non-empty) and then reused.

        Raises:
            SchemaError: if the mapping's depth is greater than 3.
        """
        if self._supported_table_args or not self.mapping:
            return self._supported_table_args

        depth = self.depth()

        if not depth:  # None
            self._supported_table_args = tuple()
        elif 1 <= depth <= 3:
            self._supported_table_args = exp.TABLE_PARTS[:depth]
        else:
            raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> list[str]:
        """Names of the table's parts, in reverse order of `table.parts`."""
        return [part.name for part in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> list[str]:
        """Reversed name parts of a possibly qualified UDF, truncated to the UDF mapping's depth."""
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        if isinstance(parent, exp.Dot):
            names = [part.name for part in parent.flatten()]
        else:
            names = [udf.name]
        return list(reversed(names))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: list[str],
        trie: dict[str, object],
        raise_on_missing: bool,
    ) -> list[str] | None:
        """
        Resolve `parts` against `trie`, completing a prefix when exactly one completion exists.

        `parts` is extended in place when a prefix is completed.

        Returns:
            The resolved parts, or None when nothing matches or the prefix is ambiguous.

        Raises:
            SchemaError: if the prefix is ambiguous and `raise_on_missing` is set.
        """
        value, trie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value != TrieResult.PREFIX:
            return parts

        possibilities = flatten_schema(trie)

        if len(possibilities) == 1:
            parts.extend(possibilities[0])
            return parts

        if raise_on_missing:
            joined_parts = ".".join(parts)
            message = ", ".join(".".join(p) for p in possibilities)
            raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

        return None

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """
        Look up the schema registered for a table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents
                (not used by this base implementation).

        Returns:
            The schema of the target table, or None when the table cannot be resolved.
        """
        all_parts = self.table_parts(table)
        parts = all_parts[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Any | None:
        """
        Look up the return type registered for a UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        resolved_parts = self._find_in_trie(self.udf_parts(udf), self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        lookup_path = zip(resolved_parts, reversed(resolved_parts))
        return nested_get(self.udf_mapping, *lookup_path, raise_on_missing=raise_on_missing)

    def nested_get(
        self,
        parts: Sequence[str],
        d: dict[str, object] | None = None,
        raise_on_missing: bool = True,
    ) -> t.Any | None:
        """Fetch the entry addressed by `parts` from `d`, falling back to the table mapping."""
        source = d or self.mapping
        lookup_path = zip(self.supported_table_args, reversed(parts))
        return nested_get(source, *lookup_path, raise_on_missing=raise_on_missing)
def __init__(
    self,
    mapping: dict[str, object] | None = None,
    udf_mapping: dict[str, object] | None = None,
) -> None:
    """
    Store the table and UDF mappings and build a lookup trie for each of them.

    Args:
        mapping: nested mapping describing the tables; an empty dict is used when falsy.
        udf_mapping: nested mapping describing the UDFs; an empty dict is used when falsy.
    """
    self.mapping: dict[str, object] = mapping or {}
    # Paths are stored reversed in the trie, so lookups start from the innermost name.
    table_paths = flatten_schema(self.mapping, depth=self.depth())
    self.mapping_trie: dict[str, object] = new_trie(tuple(reversed(path)) for path in table_paths)

    self.udf_mapping: dict[str, object] = udf_mapping or {}
    udf_paths = flatten_schema(self.udf_mapping, depth=self.udf_depth())
    self.udf_trie: dict[str, object] = new_trie(tuple(reversed(path)) for path in udf_paths)

    self._supported_table_args: tuple[str, ...] = tuple()
@property
def supported_table_args(self) -> tuple[str, ...]:
    """
    The leading `exp.TABLE_PARTS` entries that the mapping's depth can address.

    The value is computed on first use (once the mapping is non-empty) and then reused.

    Raises:
        SchemaError: if the mapping's depth is greater than 3.
    """
    # Nothing to compute when a value is already cached or there is no mapping yet.
    if self._supported_table_args or not self.mapping:
        return self._supported_table_args

    depth = self.depth()

    if not depth:  # None
        self._supported_table_args = tuple()
    elif 1 <= depth <= 3:
        self._supported_table_args = exp.TABLE_PARTS[:depth]
    else:
        raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

    return self._supported_table_args
def find(
    self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Any | None:
    """
    Look up the schema registered for a table.

    Args:
        table: the target table.
        raise_on_missing: whether to raise in case the schema is not found.
        ensure_data_types: whether to convert `str` types to their `DataType` equivalents
            (not used by this implementation).

    Returns:
        The schema of the target table, or None when the table cannot be resolved.
    """
    # Only as many parts as the mapping supports take part in the lookup.
    all_parts = self.table_parts(table)
    parts = all_parts[0 : len(self.supported_table_args)]
    resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

    if resolved_parts is None:
        return None

    return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Any | None:
    """
    Look up the return type registered for a UDF.

    Args:
        udf: the target UDF expression.
        raise_on_missing: whether to raise if the UDF is not found.

    Returns:
        The return type of the UDF, or None if not found.
    """
    resolved_parts = self._find_in_trie(self.udf_parts(udf), self.udf_trie, raise_on_missing)

    if resolved_parts is None:
        return None

    # The resolved parts serve as lookup names paired with the keys in reverse order.
    lookup_path = zip(resolved_parts, reversed(resolved_parts))
    return nested_get(self.udf_mapping, *lookup_path, raise_on_missing=raise_on_missing)
Returns the return type of a given UDF.
Arguments:
- udf: the target UDF expression.
- raise_on_missing: whether to raise if the UDF is not found.
Returns:
The return type of the UDF, or None if not found.
class MappingSchema(AbstractMappingSchema, Schema):
    """
    A schema backed by a nested mapping.

    Args:
        schema: mapping in one of the following shapes:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - tables will be added later
        visible: optional mapping of the columns that are visible. When it is not provided, all
            columns are treated as visible. Its nesting should mirror the schema's:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: the dialect used for custom type mappings and for parsing string arguments.
        normalize: whether identifier names are normalized according to the given dialect.
        udf_mapping: optional nested mapping from UDF names to their return types.
    """

    def __init__(
        self,
        schema: dict[str, object] | None = None,
        visible: dict[str, object] | None = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        self.visible: dict[str, object] = {} if visible is None else visible
        self.normalize: bool = normalize
        self._dialect: Dialect = Dialect.get_or_raise(dialect)

        # Per-instance memoization of conversions and lookups.
        self._type_mapping_cache: dict[str, exp.DataType] = {}
        self._normalized_table_cache: dict[tuple[exp.Table, DialectType, bool], exp.Table] = {}
        self._normalized_name_cache: dict[tuple[str, DialectType, bool, bool], str] = {}
        self._find_cache: dict[tuple[exp.Table, bool], dict[str, object] | None] = {}
        self._depth: int = 0

        tables = {} if schema is None else schema
        udfs = {} if udf_mapping is None else udf_mapping

        if self.normalize:
            tables = self._normalize(tables)
            udfs = self._normalize_udfs(udfs)

        super().__init__(tables, udfs)

    @property
    def dialect(self) -> Dialect:
        """The dialect this mapping schema was created with."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from the settings and mappings of an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """
        Look up the schema of a table, remembering results per `(table, ensure_data_types)`.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        cache_key = (table, ensure_data_types)
        cached = self._find_cache.get(cache_key)
        if cached is not None:
            return cached

        schema = super().find(table, raise_on_missing=raise_on_missing)
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        self._find_cache[cache_key] = schema
        return schema

    def copy(
        self, schema: dict[str, object] | None = None, **kwargs: Unpack[SchemaArgs]
    ) -> MappingSchema:
        """
        Create a new `MappingSchema` that reuses this schema's settings.

        Args:
            schema: mapping for the new schema; a shallow copy of the current mapping when None.
            **kwargs: constructor arguments that take precedence over the current settings.
        """
        mapping_kwargs: SchemaArgs = {
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
            "udf_mapping": self.udf_mapping.copy(),
            **kwargs,
        }
        new_mapping = self.mapping.copy() if schema is None else schema
        return MappingSchema(new_mapping, **mapping_kwargs)

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: ColumnMapping | None = None,
        dialect: DialectType = None,
        normalize: bool | None = None,
        match_depth: bool = True,
    ) -> None:
        """
        Add a table to the schema, or update it. An already known table is left untouched unless
        a non-empty column mapping is given. The table's path must carry as many qualifiers as
        the schema has nesting levels.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.

        Raises:
            SchemaError: if `match_depth` is set and the table does not match the depth of a
                non-empty schema.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {}
        for key, value in ensure_column_mapping(column_mapping).items():
            normalized_key = self._normalize_name(key, dialect=dialect, normalize=normalize)
            normalized_column_mapping[normalized_key] = value

        existing = self.find(normalized_table, raise_on_missing=False)
        if existing and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

        # Drop lookups remembered for this table under either `ensure_data_types` flag.
        for ensure_data_types in (True, False):
            self._find_cache.pop((normalized_table, ensure_data_types), None)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> list[str]:
        """
        Return the names of a table's columns.

        Args:
            table: the `Table` expression instance, or a string to be parsed into one.
            only_visible: whether to keep only the columns listed in the `visible` mapping.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names; empty when no schema is found for the table.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema: dict[str, object] | None = self.find(normalized_table)
        if schema is None:
            return []

        if only_visible and self.visible:
            visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
            return [col for col in schema if col in visible]

        return list(schema)

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Look up the `sqlglot.exp.DataType` of a column.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The column's type, or UNKNOWN when the table or column is not found.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        column_ref = column if isinstance(column, str) else column.this
        normalized_column_name = self._normalize_name(
            column_ref, dialect=dialect, normalize=normalize
        )

        table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
        column_type = table_schema.get(normalized_column_name) if table_schema else None

        if isinstance(column_type, exp.DataType):
            return column_type
        if isinstance(column_type, str):
            return self._to_data_type(column_type, dialect=dialect)

        return exp.DType.UNKNOWN.into_expr()

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Look up the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DType.UNKNOWN.into_expr()

        lookup_path = zip(resolved_parts, reversed(resolved_parts))
        udf_type = nested_get(self.udf_mapping, *lookup_path, raise_on_missing=False)

        if isinstance(udf_type, exp.DataType):
            return udf_type
        if isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DType.UNKNOWN.into_expr()

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> bool:
        """
        Tell whether `column` appears in the schema of `table`.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        column_ref = column if isinstance(column, str) else column.this
        normalized_column_name = self._normalize_name(
            column_ref, dialect=dialect, normalize=normalize
        )

        table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
        if not table_schema:
            return False
        return normalized_column_name in table_schema

    def _normalize(self, schema: dict[str, object]) -> dict[str, object]:
        """
        Normalize every identifier found in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if a table's entry is not a mapping, has no columns, or is nested deeper
                than the rest of the schema.
        """
        normalized_mapping: dict[str, object] = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                column_path = normalized_keys + [self._normalize_name(column_name)]
                nested_set(normalized_mapping, column_path, column_type)

        return normalized_mapping

    def _normalize_udfs(self, udfs: dict[str, object]) -> dict[str, object]:
        """
        Normalize every identifier found in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: dict[str, object] = {}

        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> list[str]:
        """
        Extract the parts of a UDF and normalize them for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).

        Raises:
            SchemaError: if a string argument does not parse into a UDF call.
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expr = exp.maybe_parse(udf, dialect=dialect)

            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")

        parts = self.udf_parts(udf)
        if not normalize:
            return parts

        return [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.Table:
        """Parse `table` into a `Table` expression and, if requested, normalize its identifiers."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Table expressions that were seen before are served from the cache.
        if isinstance(table, exp.Table):
            cached = self._normalized_table_cache.get((table, dialect, normalize))
            if cached:
                return cached

        # A copy is only made when the identifiers are about to be rewritten.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        self._normalized_table_cache[(normalized_table, dialect, normalize)] = normalized_table
        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: bool | None = None,
    ) -> str:
        """Return the (optionally normalized) name of an identifier, memoizing the result."""
        normalize = self.normalize if normalize is None else normalize

        dialect = dialect or self.dialect
        name_str = name if isinstance(name, str) else name.name
        cache_key = (name_str, dialect, is_table, normalize)

        cached = self._normalized_name_cache.get(cache_key)
        if cached:
            return cached

        normalized = normalize_name(
            name,
            dialect=dialect,
            is_table=is_table,
            normalize=normalize,
        )
        result = normalized.name

        self._normalized_name_cache[cache_key] = result
        return result

    def depth(self) -> int:
        """Nesting depth of the table mapping, not counting the column level."""
        if self._depth or self.empty:
            return self._depth

        # The columns themselves are a mapping, but we don't want to include those
        self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Turn a type given as a string into the matching `sqlglot.exp.DataType` object.

        Results are memoized per type string.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if building the type fails with an AttributeError.
        """
        if schema_type in self._type_mapping_cache:
            return self._type_mapping_cache[schema_type]

        dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
        udt = dialect.SUPPORTS_USER_DEFINED_TYPES

        try:
            expression = exp.DataType.from_str(schema_type, dialect=dialect, udt=udt)
            expression.transform(dialect.normalize_identifier, copy=False)
            self._type_mapping_cache[schema_type] = expression
        except AttributeError:
            in_dialect = f" in dialect {dialect}" if dialect else ""
            raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
def __init__(
    self,
    schema: dict[str, object] | None = None,
    visible: dict[str, object] | None = None,
    dialect: DialectType = None,
    normalize: bool = True,
    udf_mapping: dict[str, object] | None = None,
) -> None:
    """
    Set up the schema's settings and caches, then hand the mappings to the base class.

    Args:
        schema: nested mapping describing the tables; an empty mapping is used when None.
        visible: mapping of the visible columns; an empty mapping is used when None.
        dialect: the dialect used for custom type mappings and for parsing string arguments.
        normalize: whether identifier names are normalized according to the dialect.
        udf_mapping: nested mapping describing the UDFs; an empty mapping is used when None.
    """
    self.visible: dict[str, object] = {} if visible is None else visible
    self.normalize: bool = normalize
    self._dialect: Dialect = Dialect.get_or_raise(dialect)

    # Per-instance memoization of conversions and lookups.
    self._type_mapping_cache: dict[str, exp.DataType] = {}
    self._normalized_table_cache: dict[tuple[exp.Table, DialectType, bool], exp.Table] = {}
    self._normalized_name_cache: dict[tuple[str, DialectType, bool, bool], str] = {}
    self._find_cache: dict[tuple[exp.Table, bool], dict[str, object] | None] = {}
    self._depth: int = 0

    tables = {} if schema is None else schema
    udfs = {} if udf_mapping is None else udf_mapping

    if self.normalize:
        tables = self._normalize(tables)
        udfs = self._normalize_udfs(udfs)

    super().__init__(tables, udfs)
@property
def dialect(self) -> Dialect:
    """The dialect this mapping schema was created with."""
    return self._dialect
Returns the dialect for this mapping schema.
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
    """
    Build a new `MappingSchema` from the settings and mappings of an existing one.

    Args:
        mapping_schema: the schema whose mapping, visibility, dialect, normalization flag and
            UDF mapping are passed to the constructor.

    Returns:
        The newly created `MappingSchema`.
    """
    init_args = {
        "schema": mapping_schema.mapping,
        "visible": mapping_schema.visible,
        "dialect": mapping_schema.dialect,
        "normalize": mapping_schema.normalize,
        "udf_mapping": mapping_schema.udf_mapping,
    }
    return MappingSchema(**init_args)
def find(
    self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Any | None:
    """
    Look up the schema of a table, remembering results per `(table, ensure_data_types)`.

    Args:
        table: the target table.
        raise_on_missing: whether to raise in case the schema is not found.
        ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

    Returns:
        The schema of the target table.
    """
    cache_key = (table, ensure_data_types)
    cached = self._find_cache.get(cache_key)
    if cached is not None:
        return cached

    schema = super().find(table, raise_on_missing=raise_on_missing)
    if ensure_data_types and isinstance(schema, dict):
        # Only string-typed entries are converted; everything else is kept as is.
        schema = {
            col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
            for col, dtype in schema.items()
        }

    self._find_cache[cache_key] = schema
    return schema
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
def copy(
    self, schema: dict[str, object] | None = None, **kwargs: Unpack[SchemaArgs]
) -> MappingSchema:
    """
    Create a new `MappingSchema` that reuses this schema's settings.

    Args:
        schema: mapping for the new schema; a shallow copy of the current mapping when None.
        **kwargs: constructor arguments that take precedence over the current settings.

    Returns:
        The newly created `MappingSchema`.
    """
    mapping_kwargs: SchemaArgs = {
        "visible": self.visible.copy(),
        "dialect": self.dialect,
        "normalize": self.normalize,
        "udf_mapping": self.udf_mapping.copy(),
        **kwargs,
    }
    new_mapping = self.mapping.copy() if schema is None else schema
    return MappingSchema(new_mapping, **mapping_kwargs)
def add_table(
    self,
    table: exp.Table | str,
    column_mapping: ColumnMapping | None = None,
    dialect: DialectType = None,
    normalize: bool | None = None,
    match_depth: bool = True,
) -> None:
    """
    Add a table to the schema, or update it. An already known table is left untouched unless
    a non-empty column mapping is given. The table's path must carry as many qualifiers as
    the schema has nesting levels.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
        match_depth: whether to enforce that the table must match the schema's depth or not.

    Raises:
        SchemaError: if `match_depth` is set and the table does not match the depth of a
            non-empty schema.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
        raise SchemaError(
            f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
            f"schema's nesting level: {self.depth()}."
        )

    normalized_column_mapping = {}
    for key, value in ensure_column_mapping(column_mapping).items():
        normalized_key = self._normalize_name(key, dialect=dialect, normalize=normalize)
        normalized_column_mapping[normalized_key] = value

    existing = self.find(normalized_table, raise_on_missing=False)
    if existing and not normalized_column_mapping:
        return

    parts = self.table_parts(normalized_table)

    nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
    new_trie([parts], self.mapping_trie)

    # Drop lookups remembered for this table under either `ensure_data_types` flag.
    for ensure_data_types in (True, False):
        self._find_cache.pop((normalized_table, ensure_data_types), None)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names(
    self,
    table: exp.Table | str,
    only_visible: bool = False,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> list[str]:
    """
    Return the names of a table's columns.

    Args:
        table: the `Table` expression instance, or a string to be parsed into one.
        only_visible: whether to keep only the columns listed in the `visible` mapping.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.

    Returns:
        The list of column names; empty when no schema is found for the table.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    schema: dict[str, object] | None = self.find(normalized_table)
    if schema is None:
        return []

    # Filtering only happens when it was asked for and a visibility mapping exists.
    if only_visible and self.visible:
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    return list(schema)
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to return only the visible columns (invisible columns are excluded when this is set).
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
def get_column_type(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> exp.DataType:
    """
    Look up the `sqlglot.exp.DataType` of a column.

    Args:
        table: the source table.
        column: the target column.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.

    Returns:
        The column's type, or UNKNOWN when the table or column is not found.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    column_ref = column if isinstance(column, str) else column.this
    normalized_column_name = self._normalize_name(column_ref, dialect=dialect, normalize=normalize)

    table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
    column_type = table_schema.get(normalized_column_name) if table_schema else None

    if isinstance(column_type, exp.DataType):
        return column_type
    if isinstance(column_type, str):
        # Types stored as strings are converted on demand.
        return self._to_data_type(column_type, dialect=dialect)

    return exp.DType.UNKNOWN.into_expr()
Get the sqlglot.exp.DataType type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
def get_udf_type(
    self,
    udf: exp.Anonymous | str,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> exp.DataType:
    """
    Look up the return type of a UDF.

    Args:
        udf: the UDF expression or string (e.g., "db.my_func()").
        dialect: the SQL dialect for parsing string arguments.
        normalize: whether to normalize identifiers.

    Returns:
        The return type as a DataType, or UNKNOWN if not found.
    """
    parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
    resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

    if resolved_parts is None:
        return exp.DType.UNKNOWN.into_expr()

    lookup_path = zip(resolved_parts, reversed(resolved_parts))
    udf_type = nested_get(self.udf_mapping, *lookup_path, raise_on_missing=False)

    if isinstance(udf_type, exp.DataType):
        return udf_type
    if isinstance(udf_type, str):
        # Types stored as strings are converted on demand.
        return self._to_data_type(udf_type, dialect=dialect)

    return exp.DType.UNKNOWN.into_expr()
Get the return type of a UDF.
Arguments:
- udf: the UDF expression or string (e.g., "db.my_func()").
- dialect: the SQL dialect for parsing string arguments.
- normalize: whether to normalize identifiers.
Returns:
The return type as a DataType, or UNKNOWN if not found.
def has_column(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> bool:
    """
    Returns whether `column` appears in `table`'s schema.

    Args:
        table: the source table.
        column: the target column, either a name or a column expression (its `this` is used).
        dialect: the SQL dialect forwarded to the table/name normalization.
        normalize: whether to normalize identifiers according to the dialect of interest.

    Returns:
        True if the column appears in the schema, False otherwise (including when the
        table is not found or has no columns).
    """
    table_expr = self._normalize_table(table, dialect=dialect, normalize=normalize)

    raw_name = column if isinstance(column, str) else column.this
    column_key = self._normalize_name(raw_name, dialect=dialect, normalize=normalize)

    # A missing table is not an error here: it simply means "no such column".
    columns: dict[str, object] | None = self.find(table_expr, raise_on_missing=False)
    if not columns:
        return False

    return column_key in columns
Returns whether column appears in table's schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: bool | None = True,
) -> exp.Identifier:
    """
    Turn `identifier` into an `exp.Identifier`, optionally normalized for `dialect`.

    Args:
        identifier: the identifier, either as a string (parsed with `dialect`) or as an
            already-built identifier expression.
        dialect: the SQL dialect used for parsing and normalization.
        is_table: whether the identifier names a table; recorded in the identifier's
            `meta` before normalization.
        normalize: whether to normalize the identifier; when falsy the parsed (or given)
            identifier is returned untouched.

    Returns:
        The resulting identifier expression.
    """
    node = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if normalize:
        # this is used for normalize_identifier, bigquery has special rules pertaining tables
        node.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(node)

    return node
def ensure_column_mapping(mapping: ColumnMapping | None) -> dict[str, t.Any]:
    """
    Coerce the supported column-mapping representations into a dictionary.

    Args:
        mapping: one of:
            - None, which yields an empty dictionary;
            - a dictionary, which is returned as-is (the same object, not a copy);
            - a string of comma-separated `name: type` entries;
            - a list of column names, each mapped to None.

    Returns:
        A dictionary keyed by column name.

    Raises:
        ValueError: if `mapping` is none of the supported representations.
        IndexError: if an entry of a string mapping has no `:` separator.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        columns: dict[str, t.Any] = {}
        for entry in mapping.split(","):
            # Split on the first colon only, so that a type which itself contains a
            # colon (e.g. "struct<b:int>") is kept whole instead of being truncated.
            parts = entry.split(":", 1)
            columns[parts[0].strip()] = parts[1].strip()
        return columns
    if isinstance(mapping, list):
        return {name.strip(): None for name in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: dict[str, object], depth: int | None = None, keys: list[str] | None = None
) -> list[list[str]]:
    """
    Collect the key paths that lead to the tables of a nested schema dictionary.

    Args:
        schema: the (possibly nested) schema dictionary.
        depth: how many levels of keys make up a table path; derived from the depth of
            `schema` (minus one) when omitted.
        keys: the path prefix accumulated so far; used by the recursion.

    Returns:
        One key path per table, in dictionary iteration order. A path stops early when a
        non-dictionary value is reached before `depth` levels have been consumed.
    """
    prefix = keys if keys else []
    levels = depth if depth is not None else dict_depth(schema) - 1

    paths: list[list[str]] = []
    for name, child in schema.items():
        path = prefix + [name]

        if levels == 1 or not isinstance(child, dict):
            # Reached the table level, or there is nothing left to descend into.
            paths.append(path)
            continue

        if levels >= 2:
            paths.extend(flatten_schema(child, levels - 1, path))

    return paths
def nested_get(
    d: dict[str, object], *path: tuple[str, str], raise_on_missing: bool = True
) -> t.Any | None:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether to raise when a key is missing (or maps to None)
            instead of returning None.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is set.
    """
    current: t.Any = d

    for label, key in path:
        current = current.get(key)
        if current is not None:
            continue

        if not raise_on_missing:
            return None

        # The "this" level is reported as "table" in the error message.
        shown = "table" if label == "this" else label
        raise ValueError(f"Unknown {shown}: {key}")

    return current
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
  `key` is the key in the dictionary to get; `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
def nested_set(d: dict[str, t.Any], keys: Sequence[str], value: t.Any) -> dict[str, t.Any]:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary; it is the same object as `d`, which is left
        untouched when `keys` is empty.
    """
    if keys:
        *parents, leaf = keys

        node = d
        for key in parents:
            # Descend into the existing level, creating an empty one when it is absent.
            node = node.setdefault(key, {})

        node[leaf] = value

    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.