sqlglot.schema
"""Schema abstractions used to describe tables, columns, and UDF return types."""

from __future__ import annotations

import abc
import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth, first
from sqlglot.trie import TrieResult, in_trie, new_trie

from sqlglot.helper import trait


if t.TYPE_CHECKING:
    from sqlglot._typing import SchemaArgs
    from sqlglot.dialects.dialect import DialectType
    from collections.abc import Sequence
    from typing_extensions import Unpack

    # A column mapping may be a dict ({col: type}), a "col: type, ..." string, or a list of names.
    ColumnMapping = t.Union[dict, str, list]


@trait
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @property
    def dialect(self) -> Dialect | None:
        """
        Returns None by default. Subclasses that require dialect-specific
        behavior should override this property.
        """
        return None

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: ColumnMapping | None = None,
        dialect: DialectType = None,
        normalize: bool | None = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string.
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        # Default implementation: UDF return types are unknown unless a subclass tracks them.
        return exp.DType.UNKNOWN.into_expr()

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> tuple[str, ...]:
        """
        Table arguments this schema support, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        return True


class AbstractMappingSchema:
    """Base for schemas backed by a nested mapping, with trie-accelerated lookups."""

    def __init__(
        self,
        mapping: dict[str, object] | None = None,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        # `mapping` nests catalog -> db -> table -> columns (any prefix of that path).
        self.mapping: dict[str, object] = mapping or {}
        # Tries are keyed by reversed name parts (table first) to support partial qualification.
        self.mapping_trie: dict[str, object] = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self.udf_mapping: dict[str, object] = udf_mapping or {}
        self.udf_trie: dict[str, object] = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Lazily computed by the `supported_table_args` property.
        self._supported_table_args: tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether the table mapping is empty (the UDF mapping is not considered)."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the table mapping (includes the column level)."""
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        """Nesting depth of the UDF mapping."""
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> tuple[str, ...]:
        """The `exp.Table` arg names addressable at this schema's depth, e.g. ("this", "db")."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> list[str]:
        """Name parts of `table`, most-specific first (table, db, catalog)."""
        return [p.name for p in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> list[str]:
        """Name parts of a possibly-qualified UDF call, most-specific first."""
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
        return list(reversed(parts))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: list[str],
        trie: dict[str, object],
        raise_on_missing: bool,
    ) -> list[str] | None:
        """
        Resolve `parts` against `trie`, extending a unique prefix match in place.

        Returns the (possibly extended) parts on success, or None when the lookup
        fails or a prefix is ambiguous and `raise_on_missing` is False.
        """
        value, trie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # A partial match: acceptable only if exactly one completion exists.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

                return None

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        # Only keep as many qualifiers as this schema's depth supports.
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Any | None:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        parts = self.udf_parts(udf)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self,
        parts: Sequence[str],
        d: dict[str, object] | None = None,
        raise_on_missing: bool = True,
    ) -> t.Any | None:
        """Look up `parts` (most-specific first) in `d` or the table mapping."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: dict[str, object] | None = None,
        visible: dict[str, object] | None = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        self.visible: dict[str, object] = {} if visible is None else visible
        self.normalize: bool = normalize
        self._dialect: Dialect = Dialect.get_or_raise(dialect)
        # Caches: string type -> DataType, plus memoized table/name normalization.
        self._type_mapping_cache: dict[str, exp.DataType] = {}
        self._normalized_table_cache: dict[tuple[exp.Table, DialectType, bool], exp.Table] = {}
        self._normalized_name_cache: dict[tuple[str, DialectType, bool, bool], str] = {}
        self._depth: int = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )

    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Alternate constructor: shallow-rebuild from another MappingSchema."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """Like the base `find`, but can coerce string column types to `DataType`s."""
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(
        self, schema: dict[str, object] | None = None, **kwargs: Unpack[SchemaArgs]
    ) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the copied constructor args."""
        mapping_kwargs: SchemaArgs = {
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
            "udf_mapping": self.udf_mapping.copy(),
            **kwargs,
        }
        return MappingSchema(self.mapping.copy() if schema is None else schema, **mapping_kwargs)

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: ColumnMapping | None = None,
        dialect: DialectType = None,
        normalize: bool | None = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An existing table is only overwritten when new column info is supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> list[str]:
        """Column names of `table`, optionally filtered to the visible subset."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """Resolve a column's type, building a `DataType` from a string if necessary."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DType.UNKNOWN.into_expr()

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DType.UNKNOWN.into_expr()

        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DType.UNKNOWN.into_expr()

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> bool:
        """Membership test using normalized table/column names (avoids listing all columns)."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: dict[str, object]) -> dict[str, object]:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: dict[str, object] = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # Every leaf must be a non-empty {column: type} dict at a uniform depth.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_udfs(self, udfs: dict[str, object]) -> dict[str, object]:
        """
        Normalizes all identifiers in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: dict[str, object] = {}

        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> list[str]:
        """
        Extract and normalize UDF parts for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expr = exp.maybe_parse(udf, dialect=dialect)

            # Accept either a bare call or a dotted qualification ending in a call.
            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")
        parts = self.udf_parts(udf)

        if normalize:
            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

        return parts

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.Table:
        """Parse (if needed) and normalize a table expression, with memoization."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Memoize normalized tables keyed by (table, dialect, normalize); effective when
        # the same (hash-equal) Table is looked up multiple times.
        # NOTE(review): entries are stored under the *normalized* table but looked up with the
        # caller's input table, so a hit requires the input to hash-equal its normalized form —
        # confirm this asymmetry is intended.
        if isinstance(table, exp.Table) and (
            cached := self._normalized_table_cache.get((table, dialect, normalize))
        ):
            return cached

        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        self._normalized_table_cache[(normalized_table, dialect, normalize)] = normalized_table
        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: bool | None = None,
    ) -> str:
        """Normalize a single identifier to its string form, with memoization."""
        normalize = self.normalize if normalize is None else normalize

        dialect = dialect or self.dialect
        name_str = name if isinstance(name, str) else name.name
        cache_key = (name_str, dialect, is_table, normalize)

        if cached := self._normalized_name_cache.get(cache_key):
            return cached

        result = normalize_name(
            name,
            dialect=dialect,
            is_table=is_table,
            normalize=normalize,
        ).name

        self._normalized_name_cache[cache_key] = result
        return result

    def depth(self) -> int:
        """Depth of the table path only (column level excluded), computed lazily."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                expression.transform(dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: bool | None = True,
) -> exp.Identifier:
    """Parse `identifier` if needed and normalize it per `dialect` (no-op when `normalize` is falsy)."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(
    schema: Schema | dict[str, object] | None, **kwargs: Unpack[SchemaArgs]
) -> Schema:
    """Return `schema` unchanged if already a Schema, else wrap the mapping in a MappingSchema."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: ColumnMapping | None) -> dict:
    """Coerce any supported ColumnMapping form (None, dict, "col: type" string, list) to a dict."""
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    elif isinstance(mapping, list):
        # A bare list of column names carries no type information.
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: dict[str, object], depth: int | None = None, keys: list[str] | None = None
) -> list[list[str]]:
    """Flatten a nested schema into key paths, descending `depth` levels (stops above columns)."""
    tables: list[list[str]] = []
    keys = keys or []
    depth = dict_depth(schema) - 1 if depth is None else depth

    for k, v in schema.items():
        if depth == 1 or not isinstance(v, dict):
            tables.append(keys + [k])
        elif depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))

    return tables


def nested_get(
    d: dict[str, object], *path: tuple[str, str], raise_on_missing: bool = True
) -> t.Any | None:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    result: t.Any = d
    for name, key in path:
        result = result.get(key)
        if result is None:
            if raise_on_missing:
                # "this" is the exp.Table arg name for the table itself; report it readably.
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return result


def nested_set(d: dict[str, t.Any], keys: Sequence[str], value: t.Any) -> dict[str, t.Any]:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        # Create intermediate dicts as needed while walking the path.
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
25@trait 26class Schema(abc.ABC): 27 """Abstract base class for database schemas""" 28 29 @property 30 def dialect(self) -> Dialect | None: 31 """ 32 Returns None by default. Subclasses that require dialect-specific 33 behavior should override this property. 34 """ 35 return None 36 37 @abc.abstractmethod 38 def add_table( 39 self, 40 table: exp.Table | str, 41 column_mapping: ColumnMapping | None = None, 42 dialect: DialectType = None, 43 normalize: bool | None = None, 44 match_depth: bool = True, 45 ) -> None: 46 """ 47 Register or update a table. Some implementing classes may require column information to also be provided. 48 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 49 50 Args: 51 table: the `Table` expression instance or string representing the table. 52 column_mapping: a column mapping that describes the structure of the table. 53 dialect: the SQL dialect that will be used to parse `table` if it's a string. 54 normalize: whether to normalize identifiers according to the dialect of interest. 55 match_depth: whether to enforce that the table must match the schema's depth or not. 56 """ 57 58 @abc.abstractmethod 59 def column_names( 60 self, 61 table: exp.Table | str, 62 only_visible: bool = False, 63 dialect: DialectType = None, 64 normalize: bool | None = None, 65 ) -> Sequence[str]: 66 """ 67 Get the column names for a table. 68 69 Args: 70 table: the `Table` expression instance. 71 only_visible: whether to include invisible columns. 72 dialect: the SQL dialect that will be used to parse `table` if it's a string. 73 normalize: whether to normalize identifiers according to the dialect of interest. 74 75 Returns: 76 The sequence of column names. 
77 """ 78 79 @abc.abstractmethod 80 def get_column_type( 81 self, 82 table: exp.Table | str, 83 column: exp.Column | str, 84 dialect: DialectType = None, 85 normalize: bool | None = None, 86 ) -> exp.DataType: 87 """ 88 Get the `sqlglot.exp.DataType` type of a column in the schema. 89 90 Args: 91 table: the source table. 92 column: the target column. 93 dialect: the SQL dialect that will be used to parse `table` if it's a string. 94 normalize: whether to normalize identifiers according to the dialect of interest. 95 96 Returns: 97 The resulting column type. 98 """ 99 100 def has_column( 101 self, 102 table: exp.Table | str, 103 column: exp.Column | str, 104 dialect: DialectType = None, 105 normalize: bool | None = None, 106 ) -> bool: 107 """ 108 Returns whether `column` appears in `table`'s schema. 109 110 Args: 111 table: the source table. 112 column: the target column. 113 dialect: the SQL dialect that will be used to parse `table` if it's a string. 114 normalize: whether to normalize identifiers according to the dialect of interest. 115 116 Returns: 117 True if the column appears in the schema, False otherwise. 118 """ 119 name = column if isinstance(column, str) else column.name 120 return name in self.column_names(table, dialect=dialect, normalize=normalize) 121 122 def get_udf_type( 123 self, 124 udf: exp.Anonymous | str, 125 dialect: DialectType = None, 126 normalize: bool | None = None, 127 ) -> exp.DataType: 128 """ 129 Get the return type of a UDF. 130 131 Args: 132 udf: the UDF expression or string. 133 dialect: the SQL dialect for parsing string arguments. 134 normalize: whether to normalize identifiers. 135 136 Returns: 137 The return type as a DataType, or UNKNOWN if not found. 138 """ 139 return exp.DType.UNKNOWN.into_expr() 140 141 @property 142 @abc.abstractmethod 143 def supported_table_args(self) -> tuple[str, ...]: 144 """ 145 Table arguments this schema support, e.g. 
`("this", "db", "catalog")` 146 """ 147 148 @property 149 def empty(self) -> bool: 150 """Returns whether the schema is empty.""" 151 return True
Abstract base class for database schemas
29 @property 30 def dialect(self) -> Dialect | None: 31 """ 32 Returns None by default. Subclasses that require dialect-specific 33 behavior should override this property. 34 """ 35 return None
Returns None by default. Subclasses that require dialect-specific behavior should override this property.
37 @abc.abstractmethod 38 def add_table( 39 self, 40 table: exp.Table | str, 41 column_mapping: ColumnMapping | None = None, 42 dialect: DialectType = None, 43 normalize: bool | None = None, 44 match_depth: bool = True, 45 ) -> None: 46 """ 47 Register or update a table. Some implementing classes may require column information to also be provided. 48 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 49 50 Args: 51 table: the `Table` expression instance or string representing the table. 52 column_mapping: a column mapping that describes the structure of the table. 53 dialect: the SQL dialect that will be used to parse `table` if it's a string. 54 normalize: whether to normalize identifiers according to the dialect of interest. 55 match_depth: whether to enforce that the table must match the schema's depth or not. 56 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the
`Table` expression instance or string representing the table. - column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse
`table` if it's a string. - normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
58 @abc.abstractmethod 59 def column_names( 60 self, 61 table: exp.Table | str, 62 only_visible: bool = False, 63 dialect: DialectType = None, 64 normalize: bool | None = None, 65 ) -> Sequence[str]: 66 """ 67 Get the column names for a table. 68 69 Args: 70 table: the `Table` expression instance. 71 only_visible: whether to include invisible columns. 72 dialect: the SQL dialect that will be used to parse `table` if it's a string. 73 normalize: whether to normalize identifiers according to the dialect of interest. 74 75 Returns: 76 The sequence of column names. 77 """
Get the column names for a table.
Arguments:
- table: the
`Table` expression instance. - only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse
`table` if it's a string. - normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """
Get the sqlglot.exp.DataType type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
100 def has_column( 101 self, 102 table: exp.Table | str, 103 column: exp.Column | str, 104 dialect: DialectType = None, 105 normalize: bool | None = None, 106 ) -> bool: 107 """ 108 Returns whether `column` appears in `table`'s schema. 109 110 Args: 111 table: the source table. 112 column: the target column. 113 dialect: the SQL dialect that will be used to parse `table` if it's a string. 114 normalize: whether to normalize identifiers according to the dialect of interest. 115 116 Returns: 117 True if the column appears in the schema, False otherwise. 118 """ 119 name = column if isinstance(column, str) else column.name 120 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether column appears in table's schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string.
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        # Default implementation: UDF return types are unknown unless a
        # subclass (e.g. MappingSchema) carries a udf_mapping to consult.
        return exp.DType.UNKNOWN.into_expr()
Get the return type of a UDF.
Arguments:
- udf: the UDF expression or string.
- dialect: the SQL dialect for parsing string arguments.
- normalize: whether to normalize identifiers.
Returns:
The return type as a DataType, or UNKNOWN if not found.
class AbstractMappingSchema:
    """
    Shared implementation for schemas backed by a nested dict mapping.

    Table and UDF lookups go through tries built over the *reversed* name
    parts, so partially-qualified names can be resolved unambiguously or
    flagged as ambiguous.
    """

    def __init__(
        self,
        mapping: dict[str, object] | None = None,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        self.mapping: dict[str, object] = mapping or {}
        # Trie keys are the flattened key paths reversed, matching the
        # most-specific-first order produced by `table_parts`.
        self.mapping_trie: dict[str, object] = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self.udf_mapping: dict[str, object] = udf_mapping or {}
        self.udf_trie: dict[str, object] = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Lazily computed by the `supported_table_args` property.
        self._supported_table_args: tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        # True when no tables have been registered.
        return not self.mapping

    def depth(self) -> int:
        # Nesting depth of the table mapping (e.g. 3 for catalog.db.table).
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> tuple[str, ...]:
        # Derives which table-part names lookups may use from the mapping's
        # depth; the result is cached after the first successful computation.
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> list[str]:
        # Most-specific part first: table name, then db, then catalog.
        return [p.name for p in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> list[str]:
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
        return list(reversed(parts))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: list[str],
        trie: dict[str, object],
        raise_on_missing: bool,
    ) -> list[str] | None:
        # Resolve `parts` against `trie`. When `parts` is a strict prefix with
        # exactly one completion, it is extended IN PLACE with the remainder;
        # multiple completions are ambiguous. Returns None when unresolvable.
        value, trie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

                return None

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        # Trim to at most the number of qualifiers this mapping supports.
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Any | None:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        parts = self.udf_parts(udf)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self,
        parts: Sequence[str],
        d: dict[str, object] | None = None,
        raise_on_missing: bool = True,
    ) -> t.Any | None:
        # `parts` is most-specific-first while the mapping nests
        # catalog -> db -> table, hence the reversal for the lookup keys.
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
155 def __init__( 156 self, 157 mapping: dict[str, object] | None = None, 158 udf_mapping: dict[str, object] | None = None, 159 ) -> None: 160 self.mapping: dict[str, object] = mapping or {} 161 self.mapping_trie: dict[str, object] = new_trie( 162 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 163 ) 164 self.udf_mapping: dict[str, object] = udf_mapping or {} 165 self.udf_trie: dict[str, object] = new_trie( 166 tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth()) 167 ) 168 169 self._supported_table_args: tuple[str, ...] = tuple()
181 @property 182 def supported_table_args(self) -> tuple[str, ...]: 183 if not self._supported_table_args and self.mapping: 184 depth = self.depth() 185 186 if not depth: # None 187 self._supported_table_args = tuple() 188 elif 1 <= depth <= 3: 189 self._supported_table_args = exp.TABLE_PARTS[:depth] 190 else: 191 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 192 193 return self._supported_table_args
230 def find( 231 self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False 232 ) -> t.Any | None: 233 """ 234 Returns the schema of a given table. 235 236 Args: 237 table: the target table. 238 raise_on_missing: whether to raise in case the schema is not found. 239 ensure_data_types: whether to convert `str` types to their `DataType` equivalents. 240 241 Returns: 242 The schema of the target table. 243 """ 244 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 245 resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing) 246 247 if resolved_parts is None: 248 return None 249 250 return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
252 def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Any | None: 253 """ 254 Returns the return type of a given UDF. 255 256 Args: 257 udf: the target UDF expression. 258 raise_on_missing: whether to raise if the UDF is not found. 259 260 Returns: 261 The return type of the UDF, or None if not found. 262 """ 263 parts = self.udf_parts(udf) 264 resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing) 265 266 if resolved_parts is None: 267 return None 268 269 return nested_get( 270 self.udf_mapping, 271 *zip(resolved_parts, reversed(resolved_parts)), 272 raise_on_missing=raise_on_missing, 273 )
Returns the return type of a given UDF.
Arguments:
- udf: the target UDF expression.
- raise_on_missing: whether to raise if the UDF is not found.
Returns:
The return type of the UDF, or None if not found.
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: dict[str, object] | None = None,
        visible: dict[str, object] | None = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        self.visible: dict[str, object] = {} if visible is None else visible
        self.normalize: bool = normalize
        self._dialect: Dialect = Dialect.get_or_raise(dialect)
        # Caches for repeated normalization / type-building work.
        self._type_mapping_cache: dict[str, exp.DataType] = {}
        self._normalized_table_cache: dict[tuple[exp.Table, DialectType, bool], exp.Table] = {}
        self._normalized_name_cache: dict[tuple[str, DialectType, bool, bool], str] = {}
        self._depth: int = 0  # lazily computed by depth()
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )

    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        # NOTE(review): constructs MappingSchema directly instead of cls(...),
        # so subclasses calling this get a plain MappingSchema — confirm intended.
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            # Build DataType instances from any string-typed columns on the way out.
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(
        self, schema: dict[str, object] | None = None, **kwargs: Unpack[SchemaArgs]
    ) -> MappingSchema:
        # Explicit kwargs override the attributes copied from self.
        mapping_kwargs: SchemaArgs = {
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
            "udf_mapping": self.udf_mapping.copy(),
            **kwargs,
        }
        return MappingSchema(self.mapping.copy() if schema is None else schema, **mapping_kwargs)

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: ColumnMapping | None = None,
        dialect: DialectType = None,
        normalize: bool | None = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An existing table is only overwritten when new column info is supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> list[str]:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        # Filter to the columns declared visible for this table, if any.
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Stored as a string: parse it into a DataType (cached).
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DType.UNKNOWN.into_expr()

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DType.UNKNOWN.into_expr()

        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DType.UNKNOWN.into_expr()

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> bool:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: dict[str, object]) -> dict[str, object]:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: dict[str, object] = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # A leaf must be a non-empty {column: type} dict whose values are
            # NOT dicts themselves (that would mean an extra nesting level).
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_udfs(self, udfs: dict[str, object]) -> dict[str, object]:
        """
        Normalizes all identifiers in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: dict[str, object] = {}

        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> list[str]:
        """
        Extract and normalize UDF parts for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expr = exp.maybe_parse(udf, dialect=dialect)

            # Accept either a bare call f(...) or a qualified one a.b.f(...).
            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")
        parts = self.udf_parts(udf)

        if normalize:
            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

        return parts

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.Table:
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Cache normalized tables by object id for exp.Table inputs
        # This is effective when the same Table object is looked up multiple times
        if isinstance(table, exp.Table) and (
            cached := self._normalized_table_cache.get((table, dialect, normalize))
        ):
            return cached

        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        self._normalized_table_cache[(normalized_table, dialect, normalize)] = normalized_table
        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: bool | None = None,
    ) -> str:
        normalize = self.normalize if normalize is None else normalize

        dialect = dialect or self.dialect
        name_str = name if isinstance(name, str) else name.name
        cache_key = (name_str, dialect, is_table, normalize)

        if cached := self._normalized_name_cache.get(cache_key):
            return cached

        result = normalize_name(
            name,
            dialect=dialect,
            is_table=is_table,
            normalize=normalize,
        ).name

        self._normalized_name_cache[cache_key] = result
        return result

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                expression.transform(dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}}
- {db: {table: set(*cols)}}}
- {catalog: {db: {table: set(*cols)}}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
307 def __init__( 308 self, 309 schema: dict[str, object] | None = None, 310 visible: dict[str, object] | None = None, 311 dialect: DialectType = None, 312 normalize: bool = True, 313 udf_mapping: dict[str, object] | None = None, 314 ) -> None: 315 self.visible: dict[str, object] = {} if visible is None else visible 316 self.normalize: bool = normalize 317 self._dialect: Dialect = Dialect.get_or_raise(dialect) 318 self._type_mapping_cache: dict[str, exp.DataType] = {} 319 self._normalized_table_cache: dict[tuple[exp.Table, DialectType, bool], exp.Table] = {} 320 self._normalized_name_cache: dict[tuple[str, DialectType, bool, bool], str] = {} 321 self._depth: int = 0 322 schema = {} if schema is None else schema 323 udf_mapping = {} if udf_mapping is None else udf_mapping 324 325 super().__init__( 326 self._normalize(schema) if self.normalize else schema, 327 self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping, 328 )
330 @property 331 def dialect(self) -> Dialect: 332 """Returns the dialect for this mapping schema.""" 333 return self._dialect
Returns the dialect for this mapping schema.
335 @classmethod 336 def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: 337 return MappingSchema( 338 schema=mapping_schema.mapping, 339 visible=mapping_schema.visible, 340 dialect=mapping_schema.dialect, 341 normalize=mapping_schema.normalize, 342 udf_mapping=mapping_schema.udf_mapping, 343 )
345 def find( 346 self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False 347 ) -> t.Any | None: 348 schema = super().find( 349 table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types 350 ) 351 if ensure_data_types and isinstance(schema, dict): 352 schema = { 353 col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype 354 for col, dtype in schema.items() 355 } 356 357 return schema
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
359 def copy( 360 self, schema: dict[str, object] | None = None, **kwargs: Unpack[SchemaArgs] 361 ) -> MappingSchema: 362 mapping_kwargs: SchemaArgs = { 363 "visible": self.visible.copy(), 364 "dialect": self.dialect, 365 "normalize": self.normalize, 366 "udf_mapping": self.udf_mapping.copy(), 367 **kwargs, 368 } 369 return MappingSchema(self.mapping.copy() if schema is None else schema, **mapping_kwargs)
371 def add_table( 372 self, 373 table: exp.Table | str, 374 column_mapping: ColumnMapping | None = None, 375 dialect: DialectType = None, 376 normalize: bool | None = None, 377 match_depth: bool = True, 378 ) -> None: 379 """ 380 Register or update a table. Updates are only performed if a new column mapping is provided. 381 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 382 383 Args: 384 table: the `Table` expression instance or string representing the table. 385 column_mapping: a column mapping that describes the structure of the table. 386 dialect: the SQL dialect that will be used to parse `table` if it's a string. 387 normalize: whether to normalize identifiers according to the dialect of interest. 388 match_depth: whether to enforce that the table must match the schema's depth or not. 389 """ 390 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 391 392 if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): 393 raise SchemaError( 394 f"Table {normalized_table.sql(dialect=self.dialect)} must match the " 395 f"schema's nesting level: {self.depth()}." 396 ) 397 398 normalized_column_mapping = { 399 self._normalize_name(key, dialect=dialect, normalize=normalize): value 400 for key, value in ensure_column_mapping(column_mapping).items() 401 } 402 403 schema = self.find(normalized_table, raise_on_missing=False) 404 if schema and not normalized_column_mapping: 405 return 406 407 parts = self.table_parts(normalized_table) 408 409 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 410 new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
412 def column_names( 413 self, 414 table: exp.Table | str, 415 only_visible: bool = False, 416 dialect: DialectType = None, 417 normalize: bool | None = None, 418 ) -> list[str]: 419 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 420 421 schema = self.find(normalized_table) 422 if schema is None: 423 return [] 424 425 if not only_visible or not self.visible: 426 return list(schema) 427 428 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 429 return [col for col in schema if col in visible]
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to restrict the result to visible columns only.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
431 def get_column_type( 432 self, 433 table: exp.Table | str, 434 column: exp.Column | str, 435 dialect: DialectType = None, 436 normalize: bool | None = None, 437 ) -> exp.DataType: 438 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 439 440 normalized_column_name = self._normalize_name( 441 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 442 ) 443 444 table_schema = self.find(normalized_table, raise_on_missing=False) 445 if table_schema: 446 column_type = table_schema.get(normalized_column_name) 447 448 if isinstance(column_type, exp.DataType): 449 return column_type 450 elif isinstance(column_type, str): 451 return self._to_data_type(column_type, dialect=dialect) 452 453 return exp.DType.UNKNOWN.into_expr()
Get the sqlglot.exp.DataType type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
455 def get_udf_type( 456 self, 457 udf: exp.Anonymous | str, 458 dialect: DialectType = None, 459 normalize: bool | None = None, 460 ) -> exp.DataType: 461 """ 462 Get the return type of a UDF. 463 464 Args: 465 udf: the UDF expression or string (e.g., "db.my_func()"). 466 dialect: the SQL dialect for parsing string arguments. 467 normalize: whether to normalize identifiers. 468 469 Returns: 470 The return type as a DataType, or UNKNOWN if not found. 471 """ 472 parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize) 473 resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False) 474 475 if resolved_parts is None: 476 return exp.DType.UNKNOWN.into_expr() 477 478 udf_type = nested_get( 479 self.udf_mapping, 480 *zip(resolved_parts, reversed(resolved_parts)), 481 raise_on_missing=False, 482 ) 483 484 if isinstance(udf_type, exp.DataType): 485 return udf_type 486 elif isinstance(udf_type, str): 487 return self._to_data_type(udf_type, dialect=dialect) 488 489 return exp.DType.UNKNOWN.into_expr()
Get the return type of a UDF.
Arguments:
- udf: the UDF expression or string (e.g., "db.my_func()").
- dialect: the SQL dialect for parsing string arguments.
- normalize: whether to normalize identifiers.
Returns:
The return type as a DataType, or UNKNOWN if not found.
491 def has_column( 492 self, 493 table: exp.Table | str, 494 column: exp.Column | str, 495 dialect: DialectType = None, 496 normalize: bool | None = None, 497 ) -> bool: 498 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 499 500 normalized_column_name = self._normalize_name( 501 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 502 ) 503 504 table_schema = self.find(normalized_table, raise_on_missing=False) 505 return normalized_column_name in table_schema if table_schema else False
Returns whether column appears in table's schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: bool | None = True,
) -> exp.Identifier:
    """
    Parse (if necessary) and optionally normalize an identifier.

    Args:
        identifier: the identifier, or a raw string to parse into one.
        dialect: dialect used both to parse a string identifier and to normalize it.
        is_table: whether the identifier names a table; some dialects (e.g. BigQuery)
            apply special normalization rules to table names.
        normalize: when falsy, the (parsed) identifier is returned untouched.

    Returns:
        The resulting `exp.Identifier`.
    """
    parsed = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return parsed

    # normalize_identifier consults this meta flag; BigQuery has special rules
    # pertaining to tables.
    parsed.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(parsed)
def ensure_column_mapping(mapping: ColumnMapping | None) -> dict:
    """
    Coerce a column mapping into dict form.

    Accepts a dict (returned as-is), a comma-separated "name: type" string,
    a list of column names (mapped to None types), or None (empty mapping).

    Raises:
        ValueError: if `mapping` is of any other type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        # e.g. "a: INT, b: TEXT" -> {"a": "INT", "b": "TEXT"}
        columns: dict = {}
        for entry in (chunk.strip() for chunk in mapping.split(",")):
            pieces = entry.split(":")
            columns[pieces[0].strip()] = pieces[1].strip()
        return columns
    if isinstance(mapping, list):
        return {name.strip(): None for name in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: dict[str, object], depth: int | None = None, keys: list[str] | None = None
) -> list[list[str]]:
    """
    Flatten a nested schema dict into a list of table key paths.

    Args:
        schema: the (possibly nested) schema mapping.
        depth: remaining nesting levels to descend; derived from the schema when None.
        keys: path prefix accumulated across recursive calls.

    Returns:
        One list of keys per table, e.g. [["db", "table"], ...].
    """
    if depth is None:
        depth = dict_depth(schema) - 1

    prefix = keys or []
    paths: list[list[str]] = []

    for name, value in schema.items():
        if depth == 1 or not isinstance(value, dict):
            paths.append(prefix + [name])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [name]))

    return paths
def nested_get(
    d: dict[str, object], *path: tuple[str, str], raise_on_missing: bool = True
) -> t.Any | None:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether to raise when a key along the path is absent.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is True.
    """
    current: t.Any = d
    for name, key in path:
        current = current.get(key)
        if current is not None:
            continue
        if raise_on_missing:
            # "this" is how parsed table expressions label their name arg
            raise ValueError(f"Unknown {'table' if name == 'this' else name}: {key}")
        return None

    return current
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
`key` is the key in the dictionary to get. `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
def nested_set(d: dict[str, t.Any], keys: Sequence[str], value: t.Any) -> dict[str, t.Any]:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *ancestors, last = keys
    node = d
    for key in ancestors:
        # setdefault both creates missing intermediate dicts and descends
        node = node.setdefault(key, {})

    node[last] = value
    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.