sqlglot.schema
"""Schema abstractions: an abstract `Schema` interface plus a nested-dict
implementation (`MappingSchema`) that resolves tables, columns, column types
and UDF return types, with optional dialect-aware identifier normalization."""

from __future__ import annotations

import abc
import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth, first
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    # A column mapping may be a dict ({col: type}), a "col1: type1, col2: type2"
    # string, or a list of column names (types unknown).
    ColumnMapping = t.Union[t.Dict, str, t.List]


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @property
    def dialect(self) -> t.Optional[Dialect]:
        """
        Returns None by default. Subclasses that require dialect-specific
        behavior should override this property.
        """
        return None

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string.
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        # Default implementation: no UDF information is available.
        return exp.DataType.build("unknown")

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        return True


class AbstractMappingSchema:
    """Nested-dict storage with trie-based lookup of table/UDF paths.

    Both tries are keyed with the path parts reversed (most specific part
    first, e.g. (table, db, catalog)), which allows partially-qualified
    lookups to be resolved unambiguously from the innermost name outward.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        # mapping: {catalog?: {db?: {table: {col: type}}}}
        self.mapping = mapping or {}
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )

        # udf_mapping: {catalog?: {db?: {udf_name: return_type}}}
        self.udf_mapping = udf_mapping or {}
        self.udf_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Lazily computed by `supported_table_args`.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        return not self.mapping

    def depth(self) -> int:
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Cached: derived once from the mapping's nesting depth.
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        # Reversed so the most specific part comes first: [table, db, catalog].
        return [p.name for p in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
        # Reversed (most specific first) and truncated to the mapping's depth.
        return list(reversed(parts))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: t.List[str],
        trie: t.Dict,
        raise_on_missing: bool,
    ) -> t.Optional[t.List[str]]:
        """Resolve `parts` in `trie`, extending a unique prefix match in place.

        Returns the (possibly extended) parts list, or None when the path is
        missing or ambiguous and `raise_on_missing` is False.
        """
        value, trie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # Partial qualification: acceptable only if exactly one completion exists.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

                return None

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        parts = self.udf_parts(udf)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        # nested_get expects (name, key) pairs: reversed(resolved_parts) walks the
        # mapping top-down (catalog -> db -> udf); the unreversed names are only
        # used for error messages.
        return nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        # reversed(parts) walks the mapping top-down (catalog -> db -> table).
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        self._dialect = Dialect.get_or_raise(dialect)
        # Memoizes string -> DataType conversions (see `_to_data_type`).
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )

    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Alternate constructor: build a new instance from an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        # Optionally coerce string column types into DataType instances.
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        # Shallow-copies the nested dicts; kwargs override individual fields.
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                "udf_mapping": self.udf_mapping.copy(),
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Don't overwrite an existing table entry with an empty column mapping.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Incrementally extend the existing trie with the new path.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        # Without visibility info (or when not requested), return all columns.
        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DataType.build("unknown")

        # reversed(resolved_parts) walks the mapping top-down (catalog -> db -> udf).
        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Returns whether `column` appears in `table`'s schema."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # Path ends in a non-dict leaf: the table is shallower than the schema.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            # A dict where a column type should be: the table is deeper than the schema.
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: t.Dict = {}

        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Extract and normalize UDF parts for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expression = exp.maybe_parse(udf, dialect=dialect)

            # Accept either a bare call or a qualified one (Dot-wrapped).
            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")

        parts = self.udf_parts(udf)

        if normalize:
            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

        return parts

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when normalizing, so the caller's expression isn't mutated.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                expression.transform(dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse (if needed) and dialect-normalize a single identifier."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Wrap a plain mapping (or None) in a MappingSchema; pass Schemas through."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerce any supported ColumnMapping form into a {col: type} dict.

    Accepts None (empty), a dict (as-is), a "col: type, ..." string, or a
    list of column names (types set to None).
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested mapping into a list of key paths, one per leaf."""
    tables = []
    keys = keys or []
    depth = dict_depth(schema) - 1 if depth is None else depth

    for k, v in schema.items():
        if depth == 1 or not isinstance(v, dict):
            tables.append(keys + [k])
        elif depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
19class Schema(abc.ABC): 20 """Abstract base class for database schemas""" 21 22 @property 23 def dialect(self) -> t.Optional[Dialect]: 24 """ 25 Returns None by default. Subclasses that require dialect-specific 26 behavior should override this property. 27 """ 28 return None 29 30 @abc.abstractmethod 31 def add_table( 32 self, 33 table: exp.Table | str, 34 column_mapping: t.Optional[ColumnMapping] = None, 35 dialect: DialectType = None, 36 normalize: t.Optional[bool] = None, 37 match_depth: bool = True, 38 ) -> None: 39 """ 40 Register or update a table. Some implementing classes may require column information to also be provided. 41 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 42 43 Args: 44 table: the `Table` expression instance or string representing the table. 45 column_mapping: a column mapping that describes the structure of the table. 46 dialect: the SQL dialect that will be used to parse `table` if it's a string. 47 normalize: whether to normalize identifiers according to the dialect of interest. 48 match_depth: whether to enforce that the table must match the schema's depth or not. 49 """ 50 51 @abc.abstractmethod 52 def column_names( 53 self, 54 table: exp.Table | str, 55 only_visible: bool = False, 56 dialect: DialectType = None, 57 normalize: t.Optional[bool] = None, 58 ) -> t.Sequence[str]: 59 """ 60 Get the column names for a table. 61 62 Args: 63 table: the `Table` expression instance. 64 only_visible: whether to include invisible columns. 65 dialect: the SQL dialect that will be used to parse `table` if it's a string. 66 normalize: whether to normalize identifiers according to the dialect of interest. 67 68 Returns: 69 The sequence of column names. 
70 """ 71 72 @abc.abstractmethod 73 def get_column_type( 74 self, 75 table: exp.Table | str, 76 column: exp.Column | str, 77 dialect: DialectType = None, 78 normalize: t.Optional[bool] = None, 79 ) -> exp.DataType: 80 """ 81 Get the `sqlglot.exp.DataType` type of a column in the schema. 82 83 Args: 84 table: the source table. 85 column: the target column. 86 dialect: the SQL dialect that will be used to parse `table` if it's a string. 87 normalize: whether to normalize identifiers according to the dialect of interest. 88 89 Returns: 90 The resulting column type. 91 """ 92 93 def has_column( 94 self, 95 table: exp.Table | str, 96 column: exp.Column | str, 97 dialect: DialectType = None, 98 normalize: t.Optional[bool] = None, 99 ) -> bool: 100 """ 101 Returns whether `column` appears in `table`'s schema. 102 103 Args: 104 table: the source table. 105 column: the target column. 106 dialect: the SQL dialect that will be used to parse `table` if it's a string. 107 normalize: whether to normalize identifiers according to the dialect of interest. 108 109 Returns: 110 True if the column appears in the schema, False otherwise. 111 """ 112 name = column if isinstance(column, str) else column.name 113 return name in self.column_names(table, dialect=dialect, normalize=normalize) 114 115 def get_udf_type( 116 self, 117 udf: exp.Anonymous | str, 118 dialect: DialectType = None, 119 normalize: t.Optional[bool] = None, 120 ) -> exp.DataType: 121 """ 122 Get the return type of a UDF. 123 124 Args: 125 udf: the UDF expression or string. 126 dialect: the SQL dialect for parsing string arguments. 127 normalize: whether to normalize identifiers. 128 129 Returns: 130 The return type as a DataType, or UNKNOWN if not found. 131 """ 132 return exp.DataType.build("unknown") 133 134 @property 135 @abc.abstractmethod 136 def supported_table_args(self) -> t.Tuple[str, ...]: 137 """ 138 Table arguments this schema support, e.g. 
`("this", "db", "catalog")` 139 """ 140 141 @property 142 def empty(self) -> bool: 143 """Returns whether the schema is empty.""" 144 return True
Abstract base class for database schemas
22 @property 23 def dialect(self) -> t.Optional[Dialect]: 24 """ 25 Returns None by default. Subclasses that require dialect-specific 26 behavior should override this property. 27 """ 28 return None
Returns None by default. Subclasses that require dialect-specific behavior should override this property.
30 @abc.abstractmethod 31 def add_table( 32 self, 33 table: exp.Table | str, 34 column_mapping: t.Optional[ColumnMapping] = None, 35 dialect: DialectType = None, 36 normalize: t.Optional[bool] = None, 37 match_depth: bool = True, 38 ) -> None: 39 """ 40 Register or update a table. Some implementing classes may require column information to also be provided. 41 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 42 43 Args: 44 table: the `Table` expression instance or string representing the table. 45 column_mapping: a column mapping that describes the structure of the table. 46 dialect: the SQL dialect that will be used to parse `table` if it's a string. 47 normalize: whether to normalize identifiers according to the dialect of interest. 48 match_depth: whether to enforce that the table must match the schema's depth or not. 49 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
51 @abc.abstractmethod 52 def column_names( 53 self, 54 table: exp.Table | str, 55 only_visible: bool = False, 56 dialect: DialectType = None, 57 normalize: t.Optional[bool] = None, 58 ) -> t.Sequence[str]: 59 """ 60 Get the column names for a table. 61 62 Args: 63 table: the `Table` expression instance. 64 only_visible: whether to include invisible columns. 65 dialect: the SQL dialect that will be used to parse `table` if it's a string. 66 normalize: whether to normalize identifiers according to the dialect of interest. 67 68 Returns: 69 The sequence of column names. 70 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
72 @abc.abstractmethod 73 def get_column_type( 74 self, 75 table: exp.Table | str, 76 column: exp.Column | str, 77 dialect: DialectType = None, 78 normalize: t.Optional[bool] = None, 79 ) -> exp.DataType: 80 """ 81 Get the `sqlglot.exp.DataType` type of a column in the schema. 82 83 Args: 84 table: the source table. 85 column: the target column. 86 dialect: the SQL dialect that will be used to parse `table` if it's a string. 87 normalize: whether to normalize identifiers according to the dialect of interest. 88 89 Returns: 90 The resulting column type. 91 """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
93 def has_column( 94 self, 95 table: exp.Table | str, 96 column: exp.Column | str, 97 dialect: DialectType = None, 98 normalize: t.Optional[bool] = None, 99 ) -> bool: 100 """ 101 Returns whether `column` appears in `table`'s schema. 102 103 Args: 104 table: the source table. 105 column: the target column. 106 dialect: the SQL dialect that will be used to parse `table` if it's a string. 107 normalize: whether to normalize identifiers according to the dialect of interest. 108 109 Returns: 110 True if the column appears in the schema, False otherwise. 111 """ 112 name = column if isinstance(column, str) else column.name 113 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
def get_udf_type(
    self,
    udf: exp.Anonymous | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """
    Get the return type of a UDF.

    The base implementation keeps no UDF registry, so it always reports an
    unknown type; subclasses with a UDF mapping (e.g. `MappingSchema`) override it.

    Args:
        udf: the UDF expression or string.
        dialect: the SQL dialect for parsing string arguments.
        normalize: whether to normalize identifiers.

    Returns:
        The return type as a DataType, or UNKNOWN if not found.
    """
    unknown_type = exp.DataType.build("unknown")
    return unknown_type
Get the return type of a UDF.
Arguments:
- udf: the UDF expression or string.
- dialect: the SQL dialect for parsing string arguments.
- normalize: whether to normalize identifiers.
Returns:
The return type as a DataType, or UNKNOWN if not found.
class AbstractMappingSchema:
    """
    Base class for schemas backed by a nested dict mapping, with prefix tries
    for resolving partially-qualified table and UDF names.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Tries are keyed on reversed name parts (table, db, catalog) so lookups
        # can start from the least significant qualifier.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )

        self.udf_mapping = udf_mapping or {}
        self.udf_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Lazily populated cache for the `supported_table_args` property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        # True when no tables have been registered.
        return not self.mapping

    def depth(self) -> int:
        # Nesting depth of the table mapping (e.g. {table: {col: type}} -> 2).
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        # Nesting depth of the UDF mapping.
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """The table-qualifier arg names this schema's nesting depth supports; cached after first access."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        # Name parts in least-significant-first order: [table, db, catalog].
        return [p.name for p in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...)),
        # so the qualifiers live on the Dot ancestor, not the Anonymous node.
        parent = udf.parent
        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
        return list(reversed(parts))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: t.List[str],
        trie: t.Dict,
        raise_on_missing: bool,
    ) -> t.Optional[t.List[str]]:
        """
        Resolve `parts` against `trie`, extending a unique prefix match in place.

        Returns the (possibly extended) parts on success, None on failure or
        ambiguity (unless `raise_on_missing`, in which case ambiguity raises).
        """
        value, trie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # Partial qualification: acceptable only if exactly one completion exists.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

                return None

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
                (Handled by subclasses, e.g. `MappingSchema.find`; unused here.)

        Returns:
            The schema of the target table.
        """
        # Drop qualifiers beyond what this schema's nesting level supports.
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        parts = self.udf_parts(udf)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        # nested_get takes (name, key) pairs; parts are reversed back into
        # most-significant-first order for the nested mapping walk.
        return nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        # Walk the nested mapping using the supported arg names as error labels.
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
def __init__(
    self,
    mapping: t.Optional[t.Dict] = None,
    udf_mapping: t.Optional[t.Dict] = None,
) -> None:
    """Store the table/UDF mappings and build reversed-part tries for lookups."""
    self.mapping = mapping or {}
    # Trie keys are reversed name paths so partially-qualified names resolve
    # from the least significant part (table, then db, then catalog).
    self.mapping_trie = new_trie(
        tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
    )

    self.udf_mapping = udf_mapping or {}
    self.udf_trie = new_trie(
        tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
    )

    # Lazily filled by the `supported_table_args` property.
    self._supported_table_args: t.Tuple[str, ...] = tuple()
@property
def supported_table_args(self) -> t.Tuple[str, ...]:
    """The table-qualifier argument names this schema's depth supports; computed once and cached."""
    # Fast path: already cached, or nothing registered yet.
    if self._supported_table_args or not self.mapping:
        return self._supported_table_args

    depth = self.depth()

    if not depth:  # None
        args: t.Tuple[str, ...] = tuple()
    elif 1 <= depth <= 3:
        args = exp.TABLE_PARTS[:depth]
    else:
        raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

    self._supported_table_args = args
    return self._supported_table_args
def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
    """Return the UDF's name parts, least-significant first, truncated to the UDF mapping depth."""
    # A qualified call a.b.c(...) parses as Dot(Dot(a, b), Anonymous(c, ...)),
    # so qualifiers are found on the Dot ancestor rather than the node itself.
    dot = udf.parent
    if isinstance(dot, exp.Dot):
        names = [node.name for node in dot.flatten()]
    else:
        names = [udf.name]

    names.reverse()
    return names[: self.udf_depth()]
def find(
    self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Optional[t.Any]:
    """
    Returns the schema of a given table.

    Args:
        table: the target table.
        raise_on_missing: whether to raise in case the schema is not found.
        ensure_data_types: whether to convert `str` types to their `DataType` equivalents
            (applied by subclasses; unused in this base implementation).

    Returns:
        The schema of the target table.
    """
    # Only keep as many qualifiers as this schema's nesting level supports.
    lookup_parts = self.table_parts(table)[: len(self.supported_table_args)]
    resolved = self._find_in_trie(lookup_parts, self.mapping_trie, raise_on_missing)

    if resolved is None:
        return None

    return self.nested_get(resolved, raise_on_missing=raise_on_missing)
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
    """
    Returns the return type of a given UDF.

    Args:
        udf: the target UDF expression.
        raise_on_missing: whether to raise if the UDF is not found.

    Returns:
        The return type of the UDF, or None if not found.
    """
    parts = self.udf_parts(udf)
    resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

    if resolved_parts is None:
        return None

    # nested_get takes (error-label, key) pairs; the reversed copy restores
    # most-significant-first order for walking the nested UDF mapping.
    return nested_get(
        self.udf_mapping,
        *zip(resolved_parts, reversed(resolved_parts)),
        raise_on_missing=raise_on_missing,
    )
Returns the return type of a given UDF.
Arguments:
- udf: the target UDF expression.
- raise_on_missing: whether to raise if the UDF is not found.
Returns:
The return type of the UDF, or None if not found.
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        self._dialect = Dialect.get_or_raise(dialect)
        # Parsed string types are cached here by _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        # Identifiers are normalized up front so later lookups compare equal.
        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )

    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Create a new MappingSchema from another instance's attributes."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """Look up `table`'s schema; optionally coerce string-typed columns into `DataType`s."""
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            # Column types may be stored as raw strings; parse them on demand.
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """Return a shallow copy of this schema, overriding fields via `kwargs`."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                "udf_mapping": self.udf_mapping.copy(),
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Don't clobber an existing table entry with an empty mapping.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # new_trie mutates the existing trie in place when given one.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names of `table`, filtered to visible columns when requested."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the type of `column` in `table`, or an UNKNOWN type if it can't be resolved."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # String types are parsed lazily and cached by _to_data_type.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DataType.build("unknown")

        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return True if `column` exists in `table`'s registered schema."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # Validate shape: the leaf must be a non-empty flat {col: type} dict.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: t.Dict = {}

        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Extract and normalize UDF parts for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expression = exp.maybe_parse(udf, dialect=dialect)

            # Accept a bare call or a dotted call whose tail is the call node.
            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")

        parts = self.udf_parts(udf)

        if normalize:
            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

        return parts

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize its identifier parts per the dialect."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when normalizing, since parts are rewritten in place below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its bare name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                expression.transform(dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                # NOTE(review): AttributeError appears to be what DataType.build
                # surfaces for unparseable types here — confirm against exp.DataType.build.
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
def __init__(
    self,
    schema: t.Optional[t.Dict] = None,
    visible: t.Optional[t.Dict] = None,
    dialect: DialectType = None,
    normalize: bool = True,
    udf_mapping: t.Optional[t.Dict] = None,
) -> None:
    """Normalize (if requested) and register the schema and UDF mappings."""
    self.visible = {} if visible is None else visible
    self.normalize = normalize
    self._dialect = Dialect.get_or_raise(dialect)
    # Cache for string -> DataType conversions performed by _to_data_type.
    self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
    self._depth = 0
    schema = {} if schema is None else schema
    udf_mapping = {} if udf_mapping is None else udf_mapping

    # Identifiers are normalized up front so lookups later compare equal.
    super().__init__(
        self._normalize(schema) if self.normalize else schema,
        self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
    )
@property
def dialect(self) -> Dialect:
    """Returns the dialect for this mapping schema."""
    # Set once in __init__ via Dialect.get_or_raise, so never None here.
    return self._dialect
Returns the dialect for this mapping schema.
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
    """Create a new MappingSchema from another instance's attributes."""
    # NOTE(review): deliberately constructs MappingSchema rather than cls —
    # confirm subclasses are expected to receive a plain MappingSchema back.
    attributes = {
        "schema": mapping_schema.mapping,
        "visible": mapping_schema.visible,
        "dialect": mapping_schema.dialect,
        "normalize": mapping_schema.normalize,
        "udf_mapping": mapping_schema.udf_mapping,
    }
    return MappingSchema(**attributes)
def find(
    self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Optional[t.Any]:
    """Look up `table`'s schema; optionally coerce string-typed columns into `DataType`s."""
    schema = super().find(
        table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
    )
    if ensure_data_types and isinstance(schema, dict):
        # Column types may be stored as raw strings; parse them on demand.
        schema = {
            col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
            for col, dtype in schema.items()
        }

    return schema
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
def add_table(
    self,
    table: exp.Table | str,
    column_mapping: t.Optional[ColumnMapping] = None,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
    match_depth: bool = True,
) -> None:
    """
    Register or update a table. Updates are only performed if a new column mapping is provided.
    The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
        match_depth: whether to enforce that the table must match the schema's depth or not.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
        raise SchemaError(
            f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
            f"schema's nesting level: {self.depth()}."
        )

    normalized_column_mapping = {
        self._normalize_name(key, dialect=dialect, normalize=normalize): value
        for key, value in ensure_column_mapping(column_mapping).items()
    }

    # Don't overwrite an existing table entry with an empty mapping.
    schema = self.find(normalized_table, raise_on_missing=False)
    if schema and not normalized_column_mapping:
        return

    parts = self.table_parts(normalized_table)

    nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
    # new_trie mutates the existing trie in place when one is supplied.
    new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names(
    self,
    table: exp.Table | str,
    only_visible: bool = False,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> t.List[str]:
    """Return the column names of `table`, restricted to visible columns when requested."""
    resolved_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    schema = self.find(resolved_table)
    if schema is None:
        return []

    columns = list(schema)
    if only_visible and self.visible:
        # Filter out columns not present in the visibility mapping for this table.
        visible_cols = self.nested_get(self.table_parts(resolved_table), self.visible) or []
        columns = [name for name in columns if name in visible_cols]

    return columns
Get the column names for a table.
Arguments:
- table: the `Table` expression instance or string representing the table.
- only_visible: whether to restrict the result to visible columns only.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
def get_column_type(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """Return the type of `column` in `table`, or an UNKNOWN type if unresolvable."""
    resolved_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    column_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    table_schema = self.find(resolved_table, raise_on_missing=False)
    if not table_schema:
        return exp.DataType.build("unknown")

    column_type = table_schema.get(column_name)
    if isinstance(column_type, exp.DataType):
        return column_type
    if isinstance(column_type, str):
        # String types are parsed lazily and cached by _to_data_type.
        return self._to_data_type(column_type, dialect=dialect)

    return exp.DataType.build("unknown")
Get the sqlglot.exp.DataType type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
def get_udf_type(
    self,
    udf: exp.Anonymous | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """
    Get the return type of a UDF.

    Args:
        udf: the UDF expression or string (e.g., "db.my_func()").
        dialect: the SQL dialect for parsing string arguments.
        normalize: whether to normalize identifiers.

    Returns:
        The return type as a DataType, or UNKNOWN if not found.
    """
    parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
    resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

    if resolved_parts is None:
        return exp.DataType.build("unknown")

    # nested_get takes (error-label, key) pairs; the reversed copy restores
    # most-significant-first order for walking the nested UDF mapping.
    udf_type = nested_get(
        self.udf_mapping,
        *zip(resolved_parts, reversed(resolved_parts)),
        raise_on_missing=False,
    )

    if isinstance(udf_type, exp.DataType):
        return udf_type
    elif isinstance(udf_type, str):
        # Registered return types may be raw strings; parse and cache them.
        return self._to_data_type(udf_type, dialect=dialect)

    return exp.DataType.build("unknown")
Get the return type of a UDF.
Arguments:
- udf: the UDF expression or string (e.g., "db.my_func()").
- dialect: the SQL dialect for parsing string arguments.
- normalize: whether to normalize identifiers.
Returns:
The return type as a DataType, or UNKNOWN if not found.
def has_column(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> bool:
    """Return True if `column` exists in `table`'s registered schema."""
    resolved_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    column_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    table_schema = self.find(resolved_table, raise_on_missing=False)
    if not table_schema:
        return False

    return column_name in table_schema
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and case-normalize it per `dialect`'s rules."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping into dict form.

    Accepts a dict (returned as-is), a "name: type, name: type" string,
    a list of column names (types become None), or None (empty dict).
    """
    if mapping is None:
        return {}

    if isinstance(mapping, dict):
        return mapping

    if isinstance(mapping, str):
        # "col1: type1, col2: type2" -> {"col1": "type1", "col2": "type2"}
        entries = (entry.split(":") for entry in mapping.split(","))
        return {pieces[0].strip(): pieces[1].strip() for pieces in entries}

    if isinstance(mapping, list):
        return {name.strip(): None for name in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested schema mapping into the list of key paths down to (but excluding) the leaf level."""
    prefix = keys or []
    if depth is None:
        # By default, stop one level above the leaves (the column -> type dicts).
        depth = dict_depth(schema) - 1

    paths: t.List[t.List[str]] = []
    for name, subtree in schema.items():
        if depth == 1 or not isinstance(subtree, dict):
            paths.append(prefix + [name])
        elif depth >= 2:
            paths.extend(flatten_schema(subtree, depth - 1, prefix + [name]))

    return paths
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    node: t.Optional[t.Any] = d
    for name, key in path:
        node = node.get(key)  # type: ignore
        if node is not None:
            continue
        if not raise_on_missing:
            return None
        # "this" is the AST arg name for a table's own identifier; report it as "table".
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return node
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where `key` is the key in the dictionary to get and `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk down the path, creating intermediate dicts as needed, then assign at the leaf.
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.