sqlglot.schema
"""Schema abstractions used to answer questions about tables, columns and UDFs
(names, types, visibility) during SQL analysis.

`Schema` is the abstract interface; `MappingSchema` implements it on top of a
nested dict of up to three levels (catalog -> db -> table -> columns).
"""

from __future__ import annotations

import abc
import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth, first
from sqlglot.trie import TrieResult, in_trie, new_trie

from sqlglot.helper import mypyc_attr, trait


if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    # A column mapping may be a dict ({col: type}), a "col: type, ..." string,
    # or a list of column names -- see `ensure_column_mapping`.
    ColumnMapping = t.Union[t.Dict, str, t.List]


@trait
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @property
    def dialect(self) -> t.Optional[Dialect]:
        """
        Returns None by default. Subclasses that require dialect-specific
        behavior should override this property.
        """
        return None

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string.
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        # The base class knows nothing about UDFs; subclasses may override.
        return exp.DataType.build("unknown")

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema support, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        return True


@mypyc_attr(allow_interpreted_subclasses=True)
class AbstractMappingSchema:
    """Nested-dict storage for table and UDF mappings, with trie-based lookup
    so that partially-qualified names can still be resolved."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Trie keys are reversed so lookups proceed from the innermost part
        # outwards (table, db, catalog), matching the order of `table_parts`.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )

        self.udf_mapping = udf_mapping or {}
        self.udf_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Lazily computed by the `supported_table_args` property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        # True when no tables have been registered.
        return not self.mapping

    def depth(self) -> int:
        # Nesting depth of the table mapping (this includes the column level).
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Computed lazily and cached; a depth of 1..3 maps onto a prefix of
        # `exp.TABLE_PARTS`, i.e. ("this", "db", "catalog").
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        # Innermost-first, e.g. catalog.db.table -> ["table", "db", "catalog"].
        return [p.name for p in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
        # Innermost-first, truncated to the UDF mapping's nesting depth.
        return list(reversed(parts))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: t.List[str],
        trie: t.Dict,
        raise_on_missing: bool,
    ) -> t.Optional[t.List[str]]:
        """Resolve `parts` in `trie`, completing an under-qualified name when
        the completion is unambiguous.

        Returns the resolved parts (note: `parts` may be extended in place),
        or None when the lookup fails or is ambiguous.
        """
        value, trie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The name was under-qualified; it can only be resolved when
            # exactly one completion exists beneath this prefix.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

                return None

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        # Ignore qualifiers beyond the schema's own nesting level.
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        parts = self.udf_parts(udf)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        # `resolved_parts` is innermost-first, so reverse back to mapping order.
        return nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        # `parts` is innermost-first, so it's reversed back to mapping order;
        # `supported_table_args` supplies the names used in error messages.
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
        udf_mapping: Optional nested mapping of UDF names to their return types.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        self._dialect = Dialect.get_or_raise(dialect)
        # Caches to avoid re-parsing/re-normalizing repeated lookups.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._normalized_table_cache: t.Dict[t.Tuple[exp.Table, DialectType, bool], exp.Table] = {}
        self._normalized_name_cache: t.Dict[t.Tuple[str, DialectType, bool, bool], str] = {}
        # Cached nesting depth; 0 means "not computed yet" (see `depth`).
        self._depth = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )

    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        # Alternate constructor: rebuild a schema from another instance's state.
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            # Convert any string-typed columns to `exp.DataType` instances.
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        # Shallow-copies the mappings; `kwargs` can override any constructor arg.
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                "udf_mapping": self.udf_mapping.copy(),
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An existing table is only overwritten when new columns are supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # `parts` is innermost-first, so reverse for the catalog-first mapping;
        # `new_trie` updates the existing lookup trie in place.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        # Filter the columns down to the table's visibility set, if any.
        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            # Types may be stored as DataType instances or as raw strings.
            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DataType.build("unknown")

        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        # Return types may be stored as DataType instances or as raw strings.
        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # A table too shallow for the schema's nesting level.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            # A table nested deeper than the schema's nesting level.
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: t.Dict = {}

        # Full depth so that the leaf values (return types) are preserved.
        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Extract and normalize UDF parts for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expression = exp.maybe_parse(udf, dialect=dialect)

            # Accept a bare call `f(...)` or a qualified call `a.b.f(...)`.
            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")

        parts = self.udf_parts(udf)

        if normalize:
            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

        return parts

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Cache normalized tables by object id for exp.Table inputs
        # This is effective when the same Table object is looked up multiple times
        if isinstance(table, exp.Table) and (
            cached := self._normalized_table_cache.get((table, dialect, normalize))
        ):
            return cached

        # copy=normalize: only copy when we're about to mutate the parts below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        # NOTE(review): entries are also stored for `str` inputs, but only
        # `exp.Table` inputs are ever looked up above -- confirm intended.
        self._normalized_table_cache[(t.cast(exp.Table, table), dialect, normalize)] = (
            normalized_table
        )
        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        normalize = self.normalize if normalize is None else normalize

        dialect = dialect or self.dialect
        name_str = name if isinstance(name, str) else name.name
        cache_key = (name_str, dialect, is_table, normalize)

        if cached := self._normalized_name_cache.get(cache_key):
            return cached

        result = normalize_name(
            name,
            dialect=dialect,
            is_table=is_table,
            normalize=normalize,
        ).name

        self._normalized_name_cache[cache_key] = result
        return result

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                expression.transform(dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if needed and normalize it per `dialect` rules."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Return `schema` as-is if it's already a `Schema`, else wrap the dict in a `MappingSchema`."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerce a `ColumnMapping` (None, dict, "col: type, ..." string, or list of
    column names) into a {column: type} dict; list entries get a None type."""
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested dict into a list of key paths, descending `depth` levels
    (by default, one level short of the dict's full depth)."""
    tables = []
    keys = keys or []
    depth = dict_depth(schema) - 1 if depth is None else depth

    for k, v in schema.items():
        if depth == 1 or not isinstance(v, dict):
            tables.append(keys + [k])
        elif depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether to raise `ValueError` when a key is missing.

    Returns:
        The value or None if it doesn't exist.
    """
    result: t.Any = d
    for name, key in path:
        result = result.get(key)
        if result is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return result


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
22@trait 23class Schema(abc.ABC): 24 """Abstract base class for database schemas""" 25 26 @property 27 def dialect(self) -> t.Optional[Dialect]: 28 """ 29 Returns None by default. Subclasses that require dialect-specific 30 behavior should override this property. 31 """ 32 return None 33 34 @abc.abstractmethod 35 def add_table( 36 self, 37 table: exp.Table | str, 38 column_mapping: t.Optional[ColumnMapping] = None, 39 dialect: DialectType = None, 40 normalize: t.Optional[bool] = None, 41 match_depth: bool = True, 42 ) -> None: 43 """ 44 Register or update a table. Some implementing classes may require column information to also be provided. 45 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 46 47 Args: 48 table: the `Table` expression instance or string representing the table. 49 column_mapping: a column mapping that describes the structure of the table. 50 dialect: the SQL dialect that will be used to parse `table` if it's a string. 51 normalize: whether to normalize identifiers according to the dialect of interest. 52 match_depth: whether to enforce that the table must match the schema's depth or not. 53 """ 54 55 @abc.abstractmethod 56 def column_names( 57 self, 58 table: exp.Table | str, 59 only_visible: bool = False, 60 dialect: DialectType = None, 61 normalize: t.Optional[bool] = None, 62 ) -> t.Sequence[str]: 63 """ 64 Get the column names for a table. 65 66 Args: 67 table: the `Table` expression instance. 68 only_visible: whether to include invisible columns. 69 dialect: the SQL dialect that will be used to parse `table` if it's a string. 70 normalize: whether to normalize identifiers according to the dialect of interest. 71 72 Returns: 73 The sequence of column names. 
74 """ 75 76 @abc.abstractmethod 77 def get_column_type( 78 self, 79 table: exp.Table | str, 80 column: exp.Column | str, 81 dialect: DialectType = None, 82 normalize: t.Optional[bool] = None, 83 ) -> exp.DataType: 84 """ 85 Get the `sqlglot.exp.DataType` type of a column in the schema. 86 87 Args: 88 table: the source table. 89 column: the target column. 90 dialect: the SQL dialect that will be used to parse `table` if it's a string. 91 normalize: whether to normalize identifiers according to the dialect of interest. 92 93 Returns: 94 The resulting column type. 95 """ 96 97 def has_column( 98 self, 99 table: exp.Table | str, 100 column: exp.Column | str, 101 dialect: DialectType = None, 102 normalize: t.Optional[bool] = None, 103 ) -> bool: 104 """ 105 Returns whether `column` appears in `table`'s schema. 106 107 Args: 108 table: the source table. 109 column: the target column. 110 dialect: the SQL dialect that will be used to parse `table` if it's a string. 111 normalize: whether to normalize identifiers according to the dialect of interest. 112 113 Returns: 114 True if the column appears in the schema, False otherwise. 115 """ 116 name = column if isinstance(column, str) else column.name 117 return name in self.column_names(table, dialect=dialect, normalize=normalize) 118 119 def get_udf_type( 120 self, 121 udf: exp.Anonymous | str, 122 dialect: DialectType = None, 123 normalize: t.Optional[bool] = None, 124 ) -> exp.DataType: 125 """ 126 Get the return type of a UDF. 127 128 Args: 129 udf: the UDF expression or string. 130 dialect: the SQL dialect for parsing string arguments. 131 normalize: whether to normalize identifiers. 132 133 Returns: 134 The return type as a DataType, or UNKNOWN if not found. 135 """ 136 return exp.DataType.build("unknown") 137 138 @property 139 @abc.abstractmethod 140 def supported_table_args(self) -> t.Tuple[str, ...]: 141 """ 142 Table arguments this schema support, e.g. 
`("this", "db", "catalog")` 143 """ 144 145 @property 146 def empty(self) -> bool: 147 """Returns whether the schema is empty.""" 148 return True
Abstract base class for database schemas
26 @property 27 def dialect(self) -> t.Optional[Dialect]: 28 """ 29 Returns None by default. Subclasses that require dialect-specific 30 behavior should override this property. 31 """ 32 return None
Returns None by default. Subclasses that require dialect-specific behavior should override this property.
34 @abc.abstractmethod 35 def add_table( 36 self, 37 table: exp.Table | str, 38 column_mapping: t.Optional[ColumnMapping] = None, 39 dialect: DialectType = None, 40 normalize: t.Optional[bool] = None, 41 match_depth: bool = True, 42 ) -> None: 43 """ 44 Register or update a table. Some implementing classes may require column information to also be provided. 45 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 46 47 Args: 48 table: the `Table` expression instance or string representing the table. 49 column_mapping: a column mapping that describes the structure of the table. 50 dialect: the SQL dialect that will be used to parse `table` if it's a string. 51 normalize: whether to normalize identifiers according to the dialect of interest. 52 match_depth: whether to enforce that the table must match the schema's depth or not. 53 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
55 @abc.abstractmethod 56 def column_names( 57 self, 58 table: exp.Table | str, 59 only_visible: bool = False, 60 dialect: DialectType = None, 61 normalize: t.Optional[bool] = None, 62 ) -> t.Sequence[str]: 63 """ 64 Get the column names for a table. 65 66 Args: 67 table: the `Table` expression instance. 68 only_visible: whether to include invisible columns. 69 dialect: the SQL dialect that will be used to parse `table` if it's a string. 70 normalize: whether to normalize identifiers according to the dialect of interest. 71 72 Returns: 73 The sequence of column names. 74 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
76 @abc.abstractmethod 77 def get_column_type( 78 self, 79 table: exp.Table | str, 80 column: exp.Column | str, 81 dialect: DialectType = None, 82 normalize: t.Optional[bool] = None, 83 ) -> exp.DataType: 84 """ 85 Get the `sqlglot.exp.DataType` type of a column in the schema. 86 87 Args: 88 table: the source table. 89 column: the target column. 90 dialect: the SQL dialect that will be used to parse `table` if it's a string. 91 normalize: whether to normalize identifiers according to the dialect of interest. 92 93 Returns: 94 The resulting column type. 95 """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
97 def has_column( 98 self, 99 table: exp.Table | str, 100 column: exp.Column | str, 101 dialect: DialectType = None, 102 normalize: t.Optional[bool] = None, 103 ) -> bool: 104 """ 105 Returns whether `column` appears in `table`'s schema. 106 107 Args: 108 table: the source table. 109 column: the target column. 110 dialect: the SQL dialect that will be used to parse `table` if it's a string. 111 normalize: whether to normalize identifiers according to the dialect of interest. 112 113 Returns: 114 True if the column appears in the schema, False otherwise. 115 """ 116 name = column if isinstance(column, str) else column.name 117 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
119 def get_udf_type( 120 self, 121 udf: exp.Anonymous | str, 122 dialect: DialectType = None, 123 normalize: t.Optional[bool] = None, 124 ) -> exp.DataType: 125 """ 126 Get the return type of a UDF. 127 128 Args: 129 udf: the UDF expression or string. 130 dialect: the SQL dialect for parsing string arguments. 131 normalize: whether to normalize identifiers. 132 133 Returns: 134 The return type as a DataType, or UNKNOWN if not found. 135 """ 136 return exp.DataType.build("unknown")
Get the return type of a UDF.
Arguments:
- udf: the UDF expression or string.
- dialect: the SQL dialect for parsing string arguments.
- normalize: whether to normalize identifiers.
Returns:
The return type as a DataType, or UNKNOWN if not found.
@mypyc_attr(allow_interpreted_subclasses=True)
class AbstractMappingSchema:
    """Schema machinery backed by a nested mapping, with trie-accelerated path lookups."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        # Nested mapping of table path -> columns; shape depends on the depth.
        self.mapping = mapping or {}
        # Trie keyed on reversed table paths, so lookups can match partial suffixes.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )

        # Nested mapping of UDF path -> return type, plus its lookup trie.
        self.udf_mapping = udf_mapping or {}
        self.udf_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Computed lazily by the supported_table_args property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        # True when no tables have been registered yet.
        return not self.mapping

    def depth(self) -> int:
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # The table expression args (e.g. this, db, catalog) this schema can resolve,
        # derived from the mapping's nesting depth.
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        # Reversed so the table name comes first (name, db, catalog).
        return [p.name for p in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
        return list(reversed(parts))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: t.List[str],
        trie: t.Dict,
        raise_on_missing: bool,
    ) -> t.Optional[t.List[str]]:
        # NOTE: may extend `parts` in place when a unique prefix match is found.
        value, trie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

                return None

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
                (Not used here; honored by subclasses that override this method.)

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        parts = self.udf_parts(udf)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        # Each path component doubles as the error label for nested_get.
        return nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        # Pair each supported table arg name with its part (parts are reversed).
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        # Nested mapping of table path -> columns.
        self.mapping = mapping or {}
        # Trie over reversed table paths, enabling suffix-based lookups.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )

        # Nested mapping of UDF path -> return type, plus its lookup trie.
        self.udf_mapping = udf_mapping or {}
        self.udf_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Computed lazily by the supported_table_args property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()
    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Table expression args (e.g. this, db, catalog) this schema can resolve,
        # derived from the mapping's nesting depth and cached after first access.
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args
    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
        """Return the qualification path of a UDF, name first, truncated to the UDF mapping depth."""
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
        return list(reversed(parts))[0 : self.udf_depth()]
    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        # Only keep as many qualifiers as the schema's nesting supports.
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        parts = self.udf_parts(udf)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        # Each path component doubles as the error label for nested_get.
        return nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=raise_on_missing,
        )
Returns the return type of a given UDF.
Arguments:
- udf: the target UDF expression.
- raise_on_missing: whether to raise if the UDF is not found.
Returns:
The return type of the UDF, or None if not found.
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        self._dialect = Dialect.get_or_raise(dialect)
        # Caches to avoid repeated parsing/normalization of the same inputs.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._normalized_table_cache: t.Dict[t.Tuple[exp.Table, DialectType, bool], exp.Table] = {}
        self._normalized_name_cache: t.Dict[t.Tuple[str, DialectType, bool, bool], str] = {}
        self._depth = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )

    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        # Build a fresh MappingSchema that mirrors an existing one.
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """Like the base `find`, but optionally coerces `str` column types to `DataType`s."""
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        # kwargs overrides win over the copied attributes.
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                "udf_mapping": self.udf_mapping.copy(),
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Leave an existing table untouched unless new columns were supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # Record the columns and extend the lookup trie with the new path.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names of `table`, optionally filtered to visible columns."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Resolve a column's `DataType`, returning UNKNOWN when the column can't be found."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # String types are parsed (and cached) lazily.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DataType.build("unknown")

        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return True when `column` appears in `table`'s registered schema."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # Validate shape: each leaf must be a non-empty {column: type} mapping.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: t.Dict = {}

        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Extract and normalize UDF parts for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expression = exp.maybe_parse(udf, dialect=dialect)

            # Accept both bare calls (f()) and dotted calls (db.f()).
            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")
        parts = self.udf_parts(udf)

        if normalize:
            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

        return parts

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Cache normalized tables by object id for exp.Table inputs
        # This is effective when the same Table object is looked up multiple times
        if isinstance(table, exp.Table) and (
            cached := self._normalized_table_cache.get((table, dialect, normalize))
        ):
            return cached

        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        self._normalized_table_cache[(t.cast(exp.Table, table), dialect, normalize)] = (
            normalized_table
        )
        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        normalize = self.normalize if normalize is None else normalize

        dialect = dialect or self.dialect
        name_str = name if isinstance(name, str) else name.name
        cache_key = (name_str, dialect, is_table, normalize)

        if cached := self._normalized_name_cache.get(cache_key):
            return cached

        result = normalize_name(
            name,
            dialect=dialect,
            is_table=is_table,
            normalize=normalize,
        ).name

        self._normalized_name_cache[cache_key] = result
        return result

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                expression.transform(dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        self._dialect = Dialect.get_or_raise(dialect)
        # Caches to avoid repeated parsing/normalization of the same inputs.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._normalized_table_cache: t.Dict[t.Tuple[exp.Table, DialectType, bool], exp.Table] = {}
        self._normalized_name_cache: t.Dict[t.Tuple[str, DialectType, bool, bool], str] = {}
        self._depth = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        # Normalize up front so all subsequent lookups compare like with like.
        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )
    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect
Returns the dialect for this mapping schema.
    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new MappingSchema that mirrors an existing one's state."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )
341 def find( 342 self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False 343 ) -> t.Optional[t.Any]: 344 schema = super().find( 345 table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types 346 ) 347 if ensure_data_types and isinstance(schema, dict): 348 schema = { 349 col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype 350 for col, dtype in schema.items() 351 } 352 353 return schema
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Leave an existing table untouched unless new columns were supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # Record the columns and extend the lookup trie with the new path.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
408 def column_names( 409 self, 410 table: exp.Table | str, 411 only_visible: bool = False, 412 dialect: DialectType = None, 413 normalize: t.Optional[bool] = None, 414 ) -> t.List[str]: 415 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 416 417 schema = self.find(normalized_table) 418 if schema is None: 419 return [] 420 421 if not only_visible or not self.visible: 422 return list(schema) 423 424 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 425 return [col for col in schema if col in visible]
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
427 def get_column_type( 428 self, 429 table: exp.Table | str, 430 column: exp.Column | str, 431 dialect: DialectType = None, 432 normalize: t.Optional[bool] = None, 433 ) -> exp.DataType: 434 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 435 436 normalized_column_name = self._normalize_name( 437 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 438 ) 439 440 table_schema = self.find(normalized_table, raise_on_missing=False) 441 if table_schema: 442 column_type = table_schema.get(normalized_column_name) 443 444 if isinstance(column_type, exp.DataType): 445 return column_type 446 elif isinstance(column_type, str): 447 return self._to_data_type(column_type, dialect=dialect) 448 449 return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DataType.build("unknown")

        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            # String types are parsed (and cached) lazily.
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DataType.build("unknown")
Get the return type of a UDF.
Arguments:
- udf: the UDF expression or string (e.g., "db.my_func()").
- dialect: the SQL dialect for parsing string arguments.
- normalize: whether to normalize identifiers.
Returns:
The return type as a DataType, or UNKNOWN if not found.
487 def has_column( 488 self, 489 table: exp.Table | str, 490 column: exp.Column | str, 491 dialect: DialectType = None, 492 normalize: t.Optional[bool] = None, 493 ) -> bool: 494 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 495 496 normalized_column_name = self._normalize_name( 497 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 498 ) 499 500 table_schema = self.find(normalized_table, raise_on_missing=False) 501 return normalized_column_name in table_schema if table_schema else False
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it is a string and normalize it per the dialect's rules."""
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return ident

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    ident.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(ident)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping into its canonical dict form.

    Args:
        mapping: one of
            - None: treated as an empty mapping.
            - dict: returned as-is.
            - str: comma-separated "name: type" pairs; an entry without a ":"
              maps the bare column name to None, mirroring the list form.
            - list: column names only; every type defaults to None.

    Returns:
        A dict of column name to column type (None when the type is unknown).

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        # NOTE: the naive comma split means types that themselves contain commas
        # (e.g. "decimal(10, 2)") are not supported in the string form.
        columns: t.Dict[str, t.Optional[str]] = {}
        for entry in mapping.split(","):
            # Split on the first ":" only, so column types that contain ":" are
            # preserved instead of being silently truncated (the old double
            # split(":") kept just the first segment and crashed on bare names).
            name, sep, kind = entry.partition(":")
            columns[name.strip()] = kind.strip() if sep else None
        return columns
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect the qualifier paths that identify tables in a nested schema dict.

    Args:
        schema: the nested schema mapping to flatten.
        depth: how many levels of nesting remain; inferred from the dict when None.
        keys: the qualifier prefix accumulated by recursive calls.

    Returns:
        A list of qualifier paths, one per table.
    """
    prefix = keys or []
    remaining = dict_depth(schema) - 1 if depth is None else depth

    paths: t.List[t.List[str]] = []
    for name, child in schema.items():
        if remaining == 1 or not isinstance(child, dict):
            # Leaf level reached (or an unexpectedly shallow entry): record the path.
            paths.append(prefix + [name])
        elif remaining >= 2:
            paths.extend(flatten_schema(child, remaining - 1, prefix + [name]))

    return paths
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Walk a nested dictionary along `path` and return the value at the end.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where `key` is the key to descend into
            and `name` labels that level in the error message if `key` is absent.
        raise_on_missing: whether a missing key raises (otherwise None is returned).

    Returns:
        The value, or None if it doesn't exist and `raise_on_missing` is False.

    Raises:
        ValueError: if a key along the path is missing and `raise_on_missing` is True.
    """
    node: t.Any = d
    for name, key in path:
        node = node.get(key)
        if node is not None:
            continue

        if not raise_on_missing:
            return None

        # "this" is the internal arg name for a table identifier; surface the
        # friendlier label "table" to users instead.
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return node
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
`key` is the key in the dictionary to get. `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, leaf = keys
    node = d
    for key in parents:
        # setdefault both creates missing intermediate dicts and descends into
        # existing ones in a single step.
        node = node.setdefault(key, {})

    node[leaf] = value
    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}

>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that makeup the path to
`value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.