Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth, first
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12from sqlglot.helper import trait
 13
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import SchemaArgs
 17    from sqlglot.dialects.dialect import DialectType
 18    from collections.abc import Sequence
 19    from typing_extensions import Unpack
 20
 21    ColumnMapping = t.Union[dict, str, list]
 22
 23
 24@trait
 25class Schema(abc.ABC):
 26    """Abstract base class for database schemas"""
 27
 28    @property
 29    def dialect(self) -> Dialect | None:
 30        """
 31        Returns None by default. Subclasses that require dialect-specific
 32        behavior should override this property.
 33        """
 34        return None
 35
 36    @abc.abstractmethod
 37    def add_table(
 38        self,
 39        table: exp.Table | str,
 40        column_mapping: ColumnMapping | None = None,
 41        dialect: DialectType = None,
 42        normalize: bool | None = None,
 43        match_depth: bool = True,
 44    ) -> None:
 45        """
 46        Register or update a table. Some implementing classes may require column information to also be provided.
 47        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 48
 49        Args:
 50            table: the `Table` expression instance or string representing the table.
 51            column_mapping: a column mapping that describes the structure of the table.
 52            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 53            normalize: whether to normalize identifiers according to the dialect of interest.
 54            match_depth: whether to enforce that the table must match the schema's depth or not.
 55        """
 56
 57    @abc.abstractmethod
 58    def column_names(
 59        self,
 60        table: exp.Table | str,
 61        only_visible: bool = False,
 62        dialect: DialectType = None,
 63        normalize: bool | None = None,
 64    ) -> Sequence[str]:
 65        """
 66        Get the column names for a table.
 67
 68        Args:
 69            table: the `Table` expression instance.
 70            only_visible: whether to include invisible columns.
 71            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 72            normalize: whether to normalize identifiers according to the dialect of interest.
 73
 74        Returns:
 75            The sequence of column names.
 76        """
 77
 78    @abc.abstractmethod
 79    def get_column_type(
 80        self,
 81        table: exp.Table | str,
 82        column: exp.Column | str,
 83        dialect: DialectType = None,
 84        normalize: bool | None = None,
 85    ) -> exp.DataType:
 86        """
 87        Get the `sqlglot.exp.DataType` type of a column in the schema.
 88
 89        Args:
 90            table: the source table.
 91            column: the target column.
 92            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 93            normalize: whether to normalize identifiers according to the dialect of interest.
 94
 95        Returns:
 96            The resulting column type.
 97        """
 98
 99    def has_column(
100        self,
101        table: exp.Table | str,
102        column: exp.Column | str,
103        dialect: DialectType = None,
104        normalize: bool | None = None,
105    ) -> bool:
106        """
107        Returns whether `column` appears in `table`'s schema.
108
109        Args:
110            table: the source table.
111            column: the target column.
112            dialect: the SQL dialect that will be used to parse `table` if it's a string.
113            normalize: whether to normalize identifiers according to the dialect of interest.
114
115        Returns:
116            True if the column appears in the schema, False otherwise.
117        """
118        name = column if isinstance(column, str) else column.name
119        return name in self.column_names(table, dialect=dialect, normalize=normalize)
120
121    def get_udf_type(
122        self,
123        udf: exp.Anonymous | str,
124        dialect: DialectType = None,
125        normalize: bool | None = None,
126    ) -> exp.DataType:
127        """
128        Get the return type of a UDF.
129
130        Args:
131            udf: the UDF expression or string.
132            dialect: the SQL dialect for parsing string arguments.
133            normalize: whether to normalize identifiers.
134
135        Returns:
136            The return type as a DataType, or UNKNOWN if not found.
137        """
138        return exp.DType.UNKNOWN.into_expr()
139
140    @property
141    @abc.abstractmethod
142    def supported_table_args(self) -> tuple[str, ...]:
143        """
144        Table arguments this schema support, e.g. `("this", "db", "catalog")`
145        """
146
147    @property
148    def empty(self) -> bool:
149        """Returns whether the schema is empty."""
150        return True
151
152
class AbstractMappingSchema:
    """
    Shared lookup machinery for mapping-backed schemas.

    Holds a nested ``mapping`` of table paths to column info and a parallel
    ``udf_mapping`` of UDF paths to return types. Both are indexed by tries
    built over the *reversed* key paths (table, db, catalog), which lets
    partially-qualified names (e.g. just ``table``) be resolved against
    fully-qualified entries.
    """

    def __init__(
        self,
        mapping: dict[str, object] | None = None,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        self.mapping: dict[str, object] = mapping or {}
        # Trie keys are reversed paths so lookups match from the least- to
        # the most-qualified part of a name.
        self.mapping_trie: dict[str, object] = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self.udf_mapping: dict[str, object] = udf_mapping or {}
        self.udf_trie: dict[str, object] = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Computed lazily by the `supported_table_args` property.
        self._supported_table_args: tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the table mapping contains no entries."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the table mapping (e.g. {table: {col: type}} is 2)."""
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        """Nesting depth of the UDF mapping."""
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> tuple[str, ...]:
        """
        The `exp.Table` arg names this schema can resolve, derived (and cached)
        from the mapping's depth: depth 1 -> ("this",), 2 -> ("this", "db"), etc.

        Raises:
            SchemaError: if the mapping is nested more than 3 levels deep.
        """
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> list[str]:
        """Return the table's name parts in reversed (table, db, catalog) order."""
        return [p.name for p in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> list[str]:
        """
        Return the UDF's qualified name parts in reversed order, truncated to
        the UDF mapping's depth.
        """
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
        return list(reversed(parts))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: list[str],
        trie: dict[str, object],
        raise_on_missing: bool,
    ) -> list[str] | None:
        """
        Resolve `parts` (reversed name parts) against `trie`.

        If `parts` is an unambiguous prefix of exactly one entry, the missing
        qualifiers are appended to `parts` *in place* and the completed list is
        returned. Returns None when there is no match (or an ambiguous one and
        `raise_on_missing` is False).

        Raises:
            SchemaError: if the prefix is ambiguous and `raise_on_missing` is True.
        """
        value, trie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                # Unique completion: extend the caller's list with the rest of the path.
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

                return None

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Any | None:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        parts = self.udf_parts(udf)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        # `resolved_parts` is reversed; re-reverse it for the top-down dict walk
        # while using the reversed names for error messages.
        return nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self,
        parts: Sequence[str],
        d: dict[str, object] | None = None,
        raise_on_missing: bool = True,
    ) -> t.Any | None:
        """
        Walk `d` (defaulting to the table mapping) along `parts`, which are in
        reversed order and are re-reversed here to match the mapping's nesting.
        """
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
285
286
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
        udf_mapping: Optional nested mapping of (possibly qualified) UDF names to return types.
    """

    def __init__(
        self,
        schema: dict[str, object] | None = None,
        visible: dict[str, object] | None = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        self.visible: dict[str, object] = {} if visible is None else visible
        self.normalize: bool = normalize
        self._dialect: Dialect = Dialect.get_or_raise(dialect)
        # Memoization caches; keys include every input that affects the result.
        self._type_mapping_cache: dict[str, exp.DataType] = {}
        self._normalized_table_cache: dict[tuple[exp.Table, DialectType, bool], exp.Table] = {}
        self._normalized_name_cache: dict[tuple[str, DialectType, bool, bool], str] = {}
        self._depth: int = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )

    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from an existing one, reusing its state."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            # Coerce any string-typed columns into DataType instances.
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(
        self, schema: dict[str, object] | None = None, **kwargs: Unpack[SchemaArgs]
    ) -> MappingSchema:
        """Return a copy of this schema, optionally overriding the mapping or constructor args."""
        mapping_kwargs: SchemaArgs = {
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
            "udf_mapping": self.udf_mapping.copy(),
            **kwargs,
        }
        return MappingSchema(self.mapping.copy() if schema is None else schema, **mapping_kwargs)

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: ColumnMapping | None = None,
        dialect: DialectType = None,
        normalize: bool | None = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An already-registered table is only updated when new columns are given.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> list[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type, or UNKNOWN when the column can't be resolved.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DType.UNKNOWN.into_expr()

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DType.UNKNOWN.into_expr()

        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DType.UNKNOWN.into_expr()

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: dict[str, object]) -> dict[str, object]:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if the schema's nesting is inconsistent or a table has no columns.
        """
        normalized_mapping: dict[str, object] = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                # A dict-valued "column" means this path is nested deeper than its siblings.
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_udfs(self, udfs: dict[str, object]) -> dict[str, object]:
        """
        Normalizes all identifiers in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: dict[str, object] = {}

        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> list[str]:
        """
        Extract and normalize UDF parts for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).

        Raises:
            SchemaError: if a string argument doesn't parse to a UDF call.
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expr = exp.maybe_parse(udf, dialect=dialect)

            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                # Qualified call: db.my_func() parses to Dot(db, Anonymous(my_func))
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")
        parts = self.udf_parts(udf)

        if normalize:
            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

        return parts

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.Table:
        """
        Parse (if needed) and normalize `table`, caching results for `exp.Table` inputs.

        Expressions hash by their contents, so equal tables share a cache entry
        even when they are distinct objects.
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Only exp.Table inputs are cached; strings must be parsed first.
        cache_key = (table, dialect, normalize) if isinstance(table, exp.Table) else None

        if cache_key is not None and (
            cached := self._normalized_table_cache.get(cache_key)
        ) is not None:
            return cached

        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        if cache_key is not None:
            # Store under the *input* table key (the one lookups use). Previously the
            # entry was keyed by the normalized copy, so lookups never hit the cache
            # whenever normalization changed any identifier.
            self._normalized_table_cache[cache_key] = normalized_table

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: bool | None = None,
    ) -> str:
        """Normalize a single identifier to a string, with memoization."""
        normalize = self.normalize if normalize is None else normalize

        dialect = dialect or self.dialect
        name_str = name if isinstance(name, str) else name.name
        cache_key = (name_str, dialect, is_table, normalize)

        # Compare against None so a cached empty string is still a hit.
        if (cached := self._normalized_name_cache.get(cache_key)) is not None:
            return cached

        result = normalize_name(
            name,
            dialect=dialect,
            is_table=is_table,
            normalize=normalize,
        ).name

        self._normalized_name_cache[cache_key] = result
        return result

    def depth(self) -> int:
        """Schema nesting depth, excluding the column level (computed lazily)."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if `schema_type` cannot be built into a DataType.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                expression.transform(dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
683
684
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: bool | None = True,
) -> exp.Identifier:
    """
    Coerce `identifier` into an `exp.Identifier` and, unless `normalize` is
    falsy, normalize it according to `dialect`'s casing rules.
    """
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return ident

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    ident.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(ident)
700
701
def ensure_schema(
    schema: Schema | dict[str, object] | None, **kwargs: Unpack[SchemaArgs]
) -> Schema:
    """Coerce `schema` into a `Schema`, wrapping mappings (or None) in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
709
710
def ensure_column_mapping(mapping: ColumnMapping | None) -> dict:
    """
    Coerce `mapping` into a {column: type} dict.

    Accepted forms:
        - None: produces an empty mapping.
        - dict: returned as-is.
        - str: comma-separated "name: type" pairs, e.g. "a: INT, b: TEXT".
          Each entry is split on its *first* colon only, so type strings that
          contain colons (e.g. "s: STRUCT<a: INT>") keep their full text.
          Note that types containing commas are still not supported.
        - list: column names only; every type is None.

    Raises:
        ValueError: for any other input type.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns: dict = {}
        for entry in mapping.split(","):
            # Split once: previously split(":")[1] dropped everything after a
            # second colon, truncating types like "STRUCT<a: INT>".
            name, col_type = entry.split(":", 1)
            columns[name.strip()] = col_type.strip()
        return columns
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
726
727
def flatten_schema(
    schema: dict[str, object], depth: int | None = None, keys: list[str] | None = None
) -> list[list[str]]:
    """
    Flatten a nested schema mapping into the list of key paths that lead to
    its leaves, descending at most `depth` levels (defaults to one level above
    the mapping's own depth, i.e. stopping before the column level).
    """
    if depth is None:
        depth = dict_depth(schema) - 1
    prefix = keys or []

    paths: list[list[str]] = []
    for key, value in schema.items():
        if depth == 1 or not isinstance(value, dict):
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
742
743
def nested_get(
    d: dict[str, object], *path: tuple[str, str], raise_on_missing: bool = True
) -> t.Any | None:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    node: t.Any = d
    for name, key in path:
        node = node.get(key)
        if node is not None:
            continue
        if not raise_on_missing:
            return None
        # "this" is the exp.Table arg that holds the table name itself.
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return node
769
770
def nested_set(d: dict[str, t.Any], keys: Sequence[str], value: t.Any) -> dict[str, t.Any]:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk/create every intermediate level, then assign at the final key.
    *parents, last = keys
    node = d
    for key in parents:
        node = node.setdefault(key, {})

    node[last] = value
    return d
@trait
class Schema(abc.ABC):
    """Abstract base class for database schemas."""

    @property
    def dialect(self) -> Dialect | None:
        """
        The dialect associated with this schema, if any.

        The base implementation returns None; subclasses that need
        dialect-specific behavior should override this property.
        """
        return None

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: ColumnMapping | None = None,
        dialect: DialectType = None,
        normalize: bool | None = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register a new table or update an existing one. Implementations may also
        require column information. The table's path must carry enough qualifiers
        to match the schema's nesting level.

        Args:
            table: a `Table` expression, or a string to be parsed into one.
            column_mapping: describes the columns of the table being added.
            dialect: dialect used to parse `table` when it is given as a string.
            normalize: whether identifiers should be normalized per the dialect.
            match_depth: whether the table's depth must equal the schema's depth.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> Sequence[str]:
        """
        Return the column names of a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether invisible columns should be excluded.
            dialect: dialect used to parse `table` when it is given as a string.
            normalize: whether identifiers should be normalized per the dialect.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Look up the `sqlglot.exp.DataType` of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: dialect used to parse `table` when it is given as a string.
            normalize: whether identifiers should be normalized per the dialect.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> bool:
        """
        Check whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: dialect used to parse `table` when it is given as a string.
            normalize: whether identifiers should be normalized per the dialect.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        if isinstance(column, str):
            column_name = column
        else:
            column_name = column.name
        return column_name in self.column_names(table, dialect=dialect, normalize=normalize)

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Look up the return type of a UDF.

        The base implementation does not track UDFs, so it always reports an
        unknown type; subclasses may override this to perform a real lookup.

        Args:
            udf: the UDF expression or string.
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        return exp.DType.UNKNOWN.into_expr()

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Whether the schema is empty; the base implementation always reports True."""
        return True

Abstract base class for database schemas

dialect: sqlglot.dialects.Dialect | None
29    @property
30    def dialect(self) -> Dialect | None:
31        """
32        Returns None by default. Subclasses that require dialect-specific
33        behavior should override this property.
34        """
35        return None

Returns None by default. Subclasses that require dialect-specific behavior should override this property.

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.query.Table | str, column_mapping: Union[dict, str, list, NoneType] = None, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None, match_depth: bool = True) -> None:
37    @abc.abstractmethod
38    def add_table(
39        self,
40        table: exp.Table | str,
41        column_mapping: ColumnMapping | None = None,
42        dialect: DialectType = None,
43        normalize: bool | None = None,
44        match_depth: bool = True,
45    ) -> None:
46        """
47        Register or update a table. Some implementing classes may require column information to also be provided.
48        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
49
50        Args:
51            table: the `Table` expression instance or string representing the table.
52            column_mapping: a column mapping that describes the structure of the table.
53            dialect: the SQL dialect that will be used to parse `table` if it's a string.
54            normalize: whether to normalize identifiers according to the dialect of interest.
55            match_depth: whether to enforce that the table must match the schema's depth or not.
56        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.query.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> Sequence[str]:
58    @abc.abstractmethod
59    def column_names(
60        self,
61        table: exp.Table | str,
62        only_visible: bool = False,
63        dialect: DialectType = None,
64        normalize: bool | None = None,
65    ) -> Sequence[str]:
66        """
67        Get the column names for a table.
68
69        Args:
70            table: the `Table` expression instance.
71            only_visible: whether to include invisible columns.
72            dialect: the SQL dialect that will be used to parse `table` if it's a string.
73            normalize: whether to normalize identifiers according to the dialect of interest.
74
75        Returns:
76            The sequence of column names.
77        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.query.Table | str, column: sqlglot.expressions.core.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> sqlglot.expressions.datatypes.DataType:
79    @abc.abstractmethod
80    def get_column_type(
81        self,
82        table: exp.Table | str,
83        column: exp.Column | str,
84        dialect: DialectType = None,
85        normalize: bool | None = None,
86    ) -> exp.DataType:
87        """
88        Get the `sqlglot.exp.DataType` type of a column in the schema.
89
90        Args:
91            table: the source table.
92            column: the target column.
93            dialect: the SQL dialect that will be used to parse `table` if it's a string.
94            normalize: whether to normalize identifiers according to the dialect of interest.
95
96        Returns:
97            The resulting column type.
98        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.query.Table | str, column: sqlglot.expressions.core.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> bool:
100    def has_column(
101        self,
102        table: exp.Table | str,
103        column: exp.Column | str,
104        dialect: DialectType = None,
105        normalize: bool | None = None,
106    ) -> bool:
107        """
108        Returns whether `column` appears in `table`'s schema.
109
110        Args:
111            table: the source table.
112            column: the target column.
113            dialect: the SQL dialect that will be used to parse `table` if it's a string.
114            normalize: whether to normalize identifiers according to the dialect of interest.
115
116        Returns:
117            True if the column appears in the schema, False otherwise.
118        """
119        name = column if isinstance(column, str) else column.name
120        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def get_udf_type( self, udf: sqlglot.expressions.core.Anonymous | str, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> sqlglot.expressions.datatypes.DataType:
122    def get_udf_type(
123        self,
124        udf: exp.Anonymous | str,
125        dialect: DialectType = None,
126        normalize: bool | None = None,
127    ) -> exp.DataType:
128        """
129        Get the return type of a UDF.
130
131        Args:
132            udf: the UDF expression or string.
133            dialect: the SQL dialect for parsing string arguments.
134            normalize: whether to normalize identifiers.
135
136        Returns:
137            The return type as a DataType, or UNKNOWN if not found.
138        """
139        return exp.DType.UNKNOWN.into_expr()

Get the return type of a UDF.

Arguments:
  • udf: the UDF expression or string.
  • dialect: the SQL dialect for parsing string arguments.
  • normalize: whether to normalize identifiers.
Returns:

The return type as a DataType, or UNKNOWN if not found.

supported_table_args: tuple[str, ...]
141    @property
142    @abc.abstractmethod
143    def supported_table_args(self) -> tuple[str, ...]:
144        """
145        Table arguments this schema support, e.g. `("this", "db", "catalog")`
146        """

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool
148    @property
149    def empty(self) -> bool:
150        """Returns whether the schema is empty."""
151        return True

Returns whether the schema is empty.

class AbstractMappingSchema:
    """Schema backed by nested mappings, with tries for suffix-based lookup."""

    def __init__(
        self,
        mapping: dict[str, object] | None = None,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        self.mapping: dict[str, object] = mapping or {}
        self.mapping_trie: dict[str, object] = new_trie(
            tuple(reversed(keys)) for keys in flatten_schema(self.mapping, depth=self.depth())
        )
        self.udf_mapping: dict[str, object] = udf_mapping or {}
        self.udf_trie: dict[str, object] = new_trie(
            tuple(reversed(keys))
            for keys in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Computed lazily by `supported_table_args`.
        self._supported_table_args: tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether no tables have been registered."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the table mapping."""
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        """Nesting depth of the UDF mapping."""
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> tuple[str, ...]:
        if self.mapping and not self._supported_table_args:
            depth = self.depth()

            if not depth:
                self._supported_table_args = tuple()
            elif depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> list[str]:
        """Return the names of the table's parts, in reversed order."""
        return [part.name for part in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> list[str]:
        """Return the names of the UDF's qualifiers, reversed and depth-trimmed."""
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        if isinstance(parent, exp.Dot):
            names = [part.name for part in parent.flatten()]
        else:
            names = [udf.name]
        return list(reversed(names))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: list[str],
        trie: dict[str, object],
        raise_on_missing: bool,
    ) -> list[str] | None:
        """
        Resolve `parts` against `trie`, completing a unique prefix in place.

        Returns the (possibly extended) parts, or None when the lookup fails,
        or is ambiguous and `raise_on_missing` is False.
        """
        value, subtrie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # `parts` only matched a prefix of some path(s); it is usable only
            # if exactly one completion exists in the remaining subtrie.
            possibilities = flatten_schema(subtrie)

            if len(possibilities) != 1:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")
                return None

            parts.extend(possibilities[0])

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)
        if resolved is None:
            return None
        return self.nested_get(resolved, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Any | None:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        udf_path = self.udf_parts(udf)
        resolved = self._find_in_trie(udf_path, self.udf_trie, raise_on_missing)
        if resolved is None:
            return None
        return nested_get(
            self.udf_mapping,
            *zip(resolved, reversed(resolved)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self,
        parts: Sequence[str],
        d: dict[str, object] | None = None,
        raise_on_missing: bool = True,
    ) -> t.Any | None:
        """Fetch the subtree for `parts` from `d` (defaults to the table mapping)."""
        mapping = d or self.mapping
        return nested_get(
            mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
AbstractMappingSchema( mapping: dict[str, object] | None = None, udf_mapping: dict[str, object] | None = None)
155    def __init__(
156        self,
157        mapping: dict[str, object] | None = None,
158        udf_mapping: dict[str, object] | None = None,
159    ) -> None:
160        self.mapping: dict[str, object] = mapping or {}
161        self.mapping_trie: dict[str, object] = new_trie(
162            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
163        )
164        self.udf_mapping: dict[str, object] = udf_mapping or {}
165        self.udf_trie: dict[str, object] = new_trie(
166            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
167        )
168
169        self._supported_table_args: tuple[str, ...] = tuple()
mapping: dict[str, object]
mapping_trie: dict[str, object]
udf_mapping: dict[str, object]
udf_trie: dict[str, object]
empty: bool
171    @property
172    def empty(self) -> bool:
173        return not self.mapping
def depth(self) -> int:
175    def depth(self) -> int:
176        return dict_depth(self.mapping)
def udf_depth(self) -> int:
178    def udf_depth(self) -> int:
179        return dict_depth(self.udf_mapping)
supported_table_args: tuple[str, ...]
181    @property
182    def supported_table_args(self) -> tuple[str, ...]:
183        if not self._supported_table_args and self.mapping:
184            depth = self.depth()
185
186            if not depth:  # None
187                self._supported_table_args = tuple()
188            elif 1 <= depth <= 3:
189                self._supported_table_args = exp.TABLE_PARTS[:depth]
190            else:
191                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
192
193        return self._supported_table_args
def table_parts(self, table: sqlglot.expressions.query.Table) -> list[str]:
195    def table_parts(self, table: exp.Table) -> list[str]:
196        return [p.name for p in reversed(table.parts)]
def udf_parts(self, udf: sqlglot.expressions.core.Anonymous) -> list[str]:
198    def udf_parts(self, udf: exp.Anonymous) -> list[str]:
199        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
200        parent = udf.parent
201        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
202        return list(reversed(parts))[0 : self.udf_depth()]
def find( self, table: sqlglot.expressions.query.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
230    def find(
231        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
232    ) -> t.Any | None:
233        """
234        Returns the schema of a given table.
235
236        Args:
237            table: the target table.
238            raise_on_missing: whether to raise in case the schema is not found.
239            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
240
241        Returns:
242            The schema of the target table.
243        """
244        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
245        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)
246
247        if resolved_parts is None:
248            return None
249
250        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
  • ensure_data_types: whether to convert str types to their DataType equivalents.
Returns:

The schema of the target table.

def find_udf( self, udf: sqlglot.expressions.core.Anonymous, raise_on_missing: bool = False) -> Optional[Any]:
252    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Any | None:
253        """
254        Returns the return type of a given UDF.
255
256        Args:
257            udf: the target UDF expression.
258            raise_on_missing: whether to raise if the UDF is not found.
259
260        Returns:
261            The return type of the UDF, or None if not found.
262        """
263        parts = self.udf_parts(udf)
264        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)
265
266        if resolved_parts is None:
267            return None
268
269        return nested_get(
270            self.udf_mapping,
271            *zip(resolved_parts, reversed(resolved_parts)),
272            raise_on_missing=raise_on_missing,
273        )

Returns the return type of a given UDF.

Arguments:
  • udf: the target UDF expression.
  • raise_on_missing: whether to raise if the UDF is not found.
Returns:

The return type of the UDF, or None if not found.

def nested_get( self, parts: Sequence[str], d: dict[str, object] | None = None, raise_on_missing: bool = True) -> Optional[Any]:
275    def nested_get(
276        self,
277        parts: Sequence[str],
278        d: dict[str, object] | None = None,
279        raise_on_missing: bool = True,
280    ) -> t.Any | None:
281        return nested_get(
282            d or self.mapping,
283            *zip(self.supported_table_args, reversed(parts)),
284            raise_on_missing=raise_on_missing,
285        )
class MappingSchema(AbstractMappingSchema, Schema):
288class MappingSchema(AbstractMappingSchema, Schema):
289    """
290    Schema based on a nested mapping.
291
292    Args:
293        schema: Mapping in one of the following forms:
294            1. {table: {col: type}}
295            2. {db: {table: {col: type}}}
296            3. {catalog: {db: {table: {col: type}}}}
297            4. None - Tables will be added later
298        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
299            are assumed to be visible. The nesting should mirror that of the schema:
300            1. {table: set(*cols)}}
301            2. {db: {table: set(*cols)}}}
302            3. {catalog: {db: {table: set(*cols)}}}}
303        dialect: The dialect to be used for custom type mappings & parsing string arguments.
304        normalize: Whether to normalize identifier names according to the given dialect or not.
305    """
306
307    def __init__(
308        self,
309        schema: dict[str, object] | None = None,
310        visible: dict[str, object] | None = None,
311        dialect: DialectType = None,
312        normalize: bool = True,
313        udf_mapping: dict[str, object] | None = None,
314    ) -> None:
315        self.visible: dict[str, object] = {} if visible is None else visible
316        self.normalize: bool = normalize
317        self._dialect: Dialect = Dialect.get_or_raise(dialect)
318        self._type_mapping_cache: dict[str, exp.DataType] = {}
319        self._normalized_table_cache: dict[tuple[exp.Table, DialectType, bool], exp.Table] = {}
320        self._normalized_name_cache: dict[tuple[str, DialectType, bool, bool], str] = {}
321        self._depth: int = 0
322        schema = {} if schema is None else schema
323        udf_mapping = {} if udf_mapping is None else udf_mapping
324
325        super().__init__(
326            self._normalize(schema) if self.normalize else schema,
327            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
328        )
329
330    @property
331    def dialect(self) -> Dialect:
332        """Returns the dialect for this mapping schema."""
333        return self._dialect
334
335    @classmethod
336    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
337        return MappingSchema(
338            schema=mapping_schema.mapping,
339            visible=mapping_schema.visible,
340            dialect=mapping_schema.dialect,
341            normalize=mapping_schema.normalize,
342            udf_mapping=mapping_schema.udf_mapping,
343        )
344
345    def find(
346        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
347    ) -> t.Any | None:
348        schema = super().find(
349            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
350        )
351        if ensure_data_types and isinstance(schema, dict):
352            schema = {
353                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
354                for col, dtype in schema.items()
355            }
356
357        return schema
358
359    def copy(
360        self, schema: dict[str, object] | None = None, **kwargs: Unpack[SchemaArgs]
361    ) -> MappingSchema:
362        mapping_kwargs: SchemaArgs = {
363            "visible": self.visible.copy(),
364            "dialect": self.dialect,
365            "normalize": self.normalize,
366            "udf_mapping": self.udf_mapping.copy(),
367            **kwargs,
368        }
369        return MappingSchema(self.mapping.copy() if schema is None else schema, **mapping_kwargs)
370
371    def add_table(
372        self,
373        table: exp.Table | str,
374        column_mapping: ColumnMapping | None = None,
375        dialect: DialectType = None,
376        normalize: bool | None = None,
377        match_depth: bool = True,
378    ) -> None:
379        """
380        Register or update a table. Updates are only performed if a new column mapping is provided.
381        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
382
383        Args:
384            table: the `Table` expression instance or string representing the table.
385            column_mapping: a column mapping that describes the structure of the table.
386            dialect: the SQL dialect that will be used to parse `table` if it's a string.
387            normalize: whether to normalize identifiers according to the dialect of interest.
388            match_depth: whether to enforce that the table must match the schema's depth or not.
389        """
390        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
391
392        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
393            raise SchemaError(
394                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
395                f"schema's nesting level: {self.depth()}."
396            )
397
398        normalized_column_mapping = {
399            self._normalize_name(key, dialect=dialect, normalize=normalize): value
400            for key, value in ensure_column_mapping(column_mapping).items()
401        }
402
403        schema = self.find(normalized_table, raise_on_missing=False)
404        if schema and not normalized_column_mapping:
405            return
406
407        parts = self.table_parts(normalized_table)
408
409        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
410        new_trie([parts], self.mapping_trie)
411
412    def column_names(
413        self,
414        table: exp.Table | str,
415        only_visible: bool = False,
416        dialect: DialectType = None,
417        normalize: bool | None = None,
418    ) -> list[str]:
419        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
420
421        schema = self.find(normalized_table)
422        if schema is None:
423            return []
424
425        if not only_visible or not self.visible:
426            return list(schema)
427
428        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
429        return [col for col in schema if col in visible]
430
431    def get_column_type(
432        self,
433        table: exp.Table | str,
434        column: exp.Column | str,
435        dialect: DialectType = None,
436        normalize: bool | None = None,
437    ) -> exp.DataType:
438        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
439
440        normalized_column_name = self._normalize_name(
441            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
442        )
443
444        table_schema = self.find(normalized_table, raise_on_missing=False)
445        if table_schema:
446            column_type = table_schema.get(normalized_column_name)
447
448            if isinstance(column_type, exp.DataType):
449                return column_type
450            elif isinstance(column_type, str):
451                return self._to_data_type(column_type, dialect=dialect)
452
453        return exp.DType.UNKNOWN.into_expr()
454
455    def get_udf_type(
456        self,
457        udf: exp.Anonymous | str,
458        dialect: DialectType = None,
459        normalize: bool | None = None,
460    ) -> exp.DataType:
461        """
462        Get the return type of a UDF.
463
464        Args:
465            udf: the UDF expression or string (e.g., "db.my_func()").
466            dialect: the SQL dialect for parsing string arguments.
467            normalize: whether to normalize identifiers.
468
469        Returns:
470            The return type as a DataType, or UNKNOWN if not found.
471        """
472        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
473        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)
474
475        if resolved_parts is None:
476            return exp.DType.UNKNOWN.into_expr()
477
478        udf_type = nested_get(
479            self.udf_mapping,
480            *zip(resolved_parts, reversed(resolved_parts)),
481            raise_on_missing=False,
482        )
483
484        if isinstance(udf_type, exp.DataType):
485            return udf_type
486        elif isinstance(udf_type, str):
487            return self._to_data_type(udf_type, dialect=dialect)
488
489        return exp.DType.UNKNOWN.into_expr()
490
def has_column(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> bool:
    """Return True if `column` appears in `table`'s schema, False otherwise."""
    target_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    raw_name = column if isinstance(column, str) else column.this
    target_column = self._normalize_name(raw_name, dialect=dialect, normalize=normalize)

    found = self.find(target_table, raise_on_missing=False)
    if not found:
        return False
    return target_column in found
506
def _normalize(self, schema: dict[str, object]) -> dict[str, object]:
    """
    Normalizes all identifiers in the schema.

    Args:
        schema: the schema to normalize.

    Returns:
        The normalized schema mapping.

    Raises:
        SchemaError: if a table's nesting level doesn't match the schema's,
            or if a table has no columns.
    """
    normalized: dict[str, object] = {}
    flattened = flatten_schema(schema)
    error_msg = "Table {} must match the schema's nesting level: {}."

    for path in flattened:
        columns = nested_get(schema, *zip(path, path))

        # A non-dict here means this path is too shallow relative to the rest.
        if not isinstance(columns, dict):
            raise SchemaError(error_msg.format(".".join(path[:-1]), len(flattened[0])))
        if not columns:
            raise SchemaError(f"Table {'.'.join(path[:-1])} must have at least one column")
        # A dict-of-dicts at the column level means this path is too deep.
        if isinstance(first(columns.values()), dict):
            raise SchemaError(
                error_msg.format(
                    ".".join(path + flatten_schema(columns)[0]), len(flattened[0])
                ),
            )

        normalized_path = [self._normalize_name(part, is_table=True) for part in path]
        for column_name, column_type in columns.items():
            nested_set(
                normalized,
                normalized_path + [self._normalize_name(column_name)],
                column_type,
            )

    return normalized
544
def _normalize_udfs(self, udfs: dict[str, object]) -> dict[str, object]:
    """
    Normalizes all identifiers in the UDF mapping.

    Args:
        udfs: the UDF mapping to normalize.

    Returns:
        The normalized UDF mapping.
    """
    normalized: dict[str, object] = {}

    # Unlike table schemas, the leaf values here are return types, not
    # column dicts, so we flatten the full depth of the mapping.
    for path in flatten_schema(udfs, depth=dict_depth(udfs)):
        return_type = nested_get(udfs, *zip(path, path))
        normalized_path = [self._normalize_name(part, is_table=True) for part in path]
        nested_set(normalized, normalized_path, return_type)

    return normalized
563
def _normalize_udf(
    self,
    udf: exp.Anonymous | str,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> list[str]:
    """
    Extract and normalize UDF parts for lookup.

    Args:
        udf: the UDF expression or qualified string (e.g., "db.my_func()").
        dialect: the SQL dialect for parsing.
        normalize: whether to normalize identifiers.

    Returns:
        A list of normalized UDF parts (reversed for trie lookup).

    Raises:
        SchemaError: if a string argument cannot be parsed into a UDF call.
    """
    dialect = dialect or self.dialect
    normalize = self.normalize if normalize is None else normalize

    if isinstance(udf, str):
        # Fix: exp.Expr is not a sqlglot type; the base class is exp.Expression.
        parsed: exp.Expression = exp.maybe_parse(udf, dialect=dialect)

        if isinstance(parsed, exp.Anonymous):
            udf = parsed
        elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
            # A qualified call like "db.my_func()" parses into a Dot whose
            # expression is the Anonymous function node.
            udf = parsed.expression
        else:
            raise SchemaError(f"Unable to parse UDF from: {udf!r}")

    parts = self.udf_parts(udf)

    if normalize:
        parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

    return parts
599
def _normalize_table(
    self,
    table: exp.Table | str,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> exp.Table:
    """
    Parse and normalize `table`, memoizing results for exp.Table inputs.

    Args:
        table: the table expression or string to normalize.
        dialect: the SQL dialect used to parse `table` if it's a string.
        normalize: whether to normalize identifiers (defaults to self.normalize).

    Returns:
        The normalized Table expression.
    """
    dialect = dialect or self.dialect
    normalize = self.normalize if normalize is None else normalize

    # Cache normalized tables for exp.Table inputs; effective when the same
    # Table object is looked up multiple times.
    cache_key = None
    if isinstance(table, exp.Table):
        cache_key = (table, dialect, normalize)
        cached = self._normalized_table_cache.get(cache_key)
        if cached:
            return cached

    normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

    if normalize:
        for part in normalized_table.parts:
            if isinstance(part, exp.Identifier):
                part.replace(
                    normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                )

    if cache_key is not None:
        # Bug fix: store under the ORIGINAL table's key, not the normalized
        # copy's. maybe_parse(copy=True) returns a new object, so keying by
        # the copy meant lookups above could never hit the cache.
        self._normalized_table_cache[cache_key] = normalized_table

    return normalized_table
627
def _normalize_name(
    self,
    name: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: bool | None = None,
) -> str:
    """Normalize a single identifier to a string, memoizing the result."""
    normalize = self.normalize if normalize is None else normalize
    dialect = dialect or self.dialect

    raw = name if isinstance(name, str) else name.name
    key = (raw, dialect, is_table, normalize)

    hit = self._normalized_name_cache.get(key)
    if hit:
        return hit

    normalized = normalize_name(
        name,
        dialect=dialect,
        is_table=is_table,
        normalize=normalize,
    ).name
    self._normalized_name_cache[key] = normalized
    return normalized
653
def depth(self) -> int:
    """Return the schema's nesting depth, computed lazily on first use."""
    if not self.empty and not self._depth:
        # Subtract one: the innermost dict maps columns to types, which
        # does not count as a nesting level.
        self._depth = super().depth() - 1
    return self._depth
659
def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
    """
    Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

    Args:
        schema_type: the type we want to convert.
        dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

    Returns:
        The resulting expression type.

    Raises:
        SchemaError: if `schema_type` cannot be built into a DataType.
    """
    if schema_type not in self._type_mapping_cache:
        dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
        udt = dialect.SUPPORTS_USER_DEFINED_TYPES

        try:
            expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
            # Normalize identifiers inside the type (e.g. struct field names).
            expression.transform(dialect.normalize_identifier, copy=False)
            self._type_mapping_cache[schema_type] = expression
        except AttributeError as e:
            in_dialect = f" in dialect {dialect}" if dialect else ""
            # Chain the original error so the underlying parse failure
            # isn't masked from the caller's traceback.
            raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") from e

    return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: dict[str, object] | None = None, visible: dict[str, object] | None = None, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool = True, udf_mapping: dict[str, object] | None = None)
def __init__(
    self,
    schema: dict[str, object] | None = None,
    visible: dict[str, object] | None = None,
    dialect: DialectType = None,
    normalize: bool = True,
    udf_mapping: dict[str, object] | None = None,
) -> None:
    """Initialize a mapping-backed schema, optionally normalizing identifiers."""
    self.visible: dict[str, object] = visible if visible is not None else {}
    self.normalize: bool = normalize
    self._dialect: Dialect = Dialect.get_or_raise(dialect)

    # Memoization caches for parsed types and normalized identifiers/tables.
    self._type_mapping_cache: dict[str, exp.DataType] = {}
    self._normalized_table_cache: dict[tuple[exp.Table, DialectType, bool], exp.Table] = {}
    self._normalized_name_cache: dict[tuple[str, DialectType, bool, bool], str] = {}
    self._depth: int = 0

    raw_schema = {} if schema is None else schema
    raw_udfs = {} if udf_mapping is None else udf_mapping

    if self.normalize:
        raw_schema = self._normalize(raw_schema)
        raw_udfs = self._normalize_udfs(raw_udfs)

    super().__init__(raw_schema, raw_udfs)
visible: dict[str, object]
normalize: bool
dialect: sqlglot.dialects.Dialect
@property
def dialect(self) -> Dialect:
    """The Dialect instance this mapping schema uses for parsing and normalization."""
    return self._dialect

Returns the dialect for this mapping schema.

@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
    """Build a new MappingSchema mirroring an existing one's state."""
    source = mapping_schema
    return MappingSchema(
        schema=source.mapping,
        visible=source.visible,
        dialect=source.dialect,
        normalize=source.normalize,
        udf_mapping=source.udf_mapping,
    )
def find( self, table: sqlglot.expressions.query.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
def find(
    self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Any | None:
    """
    Returns the schema of a given table.

    Args:
        table: the target table.
        raise_on_missing: whether to raise in case the schema is not found.
        ensure_data_types: whether to convert str types to their DataType equivalents.

    Returns:
        The schema of the target table.
    """
    found = super().find(
        table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
    )

    if ensure_data_types and isinstance(found, dict):
        # Coerce any str-typed columns into DataType expressions.
        found = {
            name: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
            for name, dtype in found.items()
        }

    return found

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
  • ensure_data_types: whether to convert str types to their DataType equivalents.
Returns:

The schema of the target table.

def copy( self, schema: dict[str, object] | None = None, **kwargs: typing_extensions.Unpack[sqlglot._typing.SchemaArgs]) -> MappingSchema:
def copy(
    self, schema: dict[str, object] | None = None, **kwargs: Unpack[SchemaArgs]
) -> MappingSchema:
    """Return a copy of this schema; `kwargs` override the constructor arguments."""
    overrides: SchemaArgs = {
        "visible": self.visible.copy(),
        "dialect": self.dialect,
        "normalize": self.normalize,
        "udf_mapping": self.udf_mapping.copy(),
        **kwargs,
    }
    base_schema = self.mapping.copy() if schema is None else schema
    return MappingSchema(base_schema, **overrides)
def add_table( self, table: sqlglot.expressions.query.Table | str, column_mapping: Union[dict, str, list, NoneType] = None, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None, match_depth: bool = True) -> None:
def add_table(
    self,
    table: exp.Table | str,
    column_mapping: ColumnMapping | None = None,
    dialect: DialectType = None,
    normalize: bool | None = None,
    match_depth: bool = True,
) -> None:
    """
    Register or update a table. Updates are only performed if a new column mapping is provided.
    The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
        match_depth: whether to enforce that the table must match the schema's depth or not.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
        raise SchemaError(
            f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
            f"schema's nesting level: {self.depth()}."
        )

    normalized_columns = {}
    for name, col_type in ensure_column_mapping(column_mapping).items():
        normalized_name = self._normalize_name(name, dialect=dialect, normalize=normalize)
        normalized_columns[normalized_name] = col_type

    # Nothing to do when the table is already known and no new columns were given.
    if not normalized_columns and self.find(normalized_table, raise_on_missing=False):
        return

    parts = self.table_parts(normalized_table)

    # The mapping is keyed outermost-first, while table parts are innermost-first.
    nested_set(self.mapping, tuple(reversed(parts)), normalized_columns)
    new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.query.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> list[str]:
def column_names(
    self,
    table: exp.Table | str,
    only_visible: bool = False,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> list[str]:
    """
    Get the column names for a table.

    Args:
        table: the `Table` expression instance or string representing the table.
        only_visible: whether to exclude invisible columns.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.

    Returns:
        The list of column names.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
    schema = self.find(normalized_table)

    if schema is None:
        return []
    if not only_visible or not self.visible:
        return list(schema)

    visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
    return [name for name in schema if name in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

def get_column_type( self, table: sqlglot.expressions.query.Table | str, column: sqlglot.expressions.core.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> sqlglot.expressions.datatypes.DataType:
def get_column_type(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> exp.DataType:
    """
    Get the `sqlglot.exp.DataType` type of a column in the schema.

    Args:
        table: the source table.
        column: the target column.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.

    Returns:
        The resulting column type, or the UNKNOWN type if not found.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    normalized_column_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    table_schema = self.find(normalized_table, raise_on_missing=False)
    if table_schema:
        column_type = table_schema.get(normalized_column_name)

        if isinstance(column_type, exp.DataType):
            return column_type
        elif isinstance(column_type, str):
            # String types are parsed lazily (and cached) via _to_data_type.
            return self._to_data_type(column_type, dialect=dialect)

    # Fix: exp.DType.UNKNOWN.into_expr() is not a sqlglot API; build the
    # UNKNOWN type via DataType.build instead.
    return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def get_udf_type( self, udf: sqlglot.expressions.core.Anonymous | str, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> sqlglot.expressions.datatypes.DataType:
def get_udf_type(
    self,
    udf: exp.Anonymous | str,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> exp.DataType:
    """
    Get the return type of a UDF.

    Args:
        udf: the UDF expression or string (e.g., "db.my_func()").
        dialect: the SQL dialect for parsing string arguments.
        normalize: whether to normalize identifiers.

    Returns:
        The return type as a DataType, or the UNKNOWN type if not found.
    """
    parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
    resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

    if resolved_parts is None:
        # UDF is not registered at all.
        # Fix: exp.DType.UNKNOWN.into_expr() is not a sqlglot API; the
        # canonical way to construct the UNKNOWN type is DataType.build.
        return exp.DataType.build("unknown")

    udf_type = nested_get(
        self.udf_mapping,
        *zip(resolved_parts, reversed(resolved_parts)),
        raise_on_missing=False,
    )

    if isinstance(udf_type, exp.DataType):
        return udf_type
    elif isinstance(udf_type, str):
        # String types are parsed lazily (and cached) via _to_data_type.
        return self._to_data_type(udf_type, dialect=dialect)

    return exp.DataType.build("unknown")

Get the return type of a UDF.

Arguments:
  • udf: the UDF expression or string (e.g., "db.my_func()").
  • dialect: the SQL dialect for parsing string arguments.
  • normalize: whether to normalize identifiers.
Returns:

The return type as a DataType, or UNKNOWN if not found.

def has_column( self, table: sqlglot.expressions.query.Table | str, column: sqlglot.expressions.core.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> bool:
def has_column(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: bool | None = None,
) -> bool:
    """
    Returns whether `column` appears in `table`'s schema.

    Args:
        table: the source table.
        column: the target column.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.

    Returns:
        True if the column appears in the schema, False otherwise.
    """
    target_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    raw_name = column if isinstance(column, str) else column.this
    target_column = self._normalize_name(raw_name, dialect=dialect, normalize=normalize)

    found = self.find(target_table, raise_on_missing=False)
    return bool(found) and target_column in found

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
def depth(self) -> int:
    """Return the schema's nesting depth, computed lazily on first access."""
    if not self.empty and not self._depth:
        # The innermost dict maps columns to types; it is not a nesting level.
        self._depth = super().depth() - 1
    return self._depth
def normalize_name( identifier: str | sqlglot.expressions.core.Identifier, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, is_table: bool = False, normalize: bool | None = True) -> sqlglot.expressions.core.Identifier:
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: bool | None = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and normalize it per `dialect`."""
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return ident

    # Consulted by normalize_identifier; BigQuery has special rules for tables.
    ident.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(ident)
def ensure_schema( schema: Schema | dict[str, object] | None, **kwargs: typing_extensions.Unpack[sqlglot._typing.SchemaArgs]) -> Schema:
def ensure_schema(
    schema: Schema | dict[str, object] | None, **kwargs: Unpack[SchemaArgs]
) -> Schema:
    """Coerce `schema` into a Schema, wrapping plain mappings in a MappingSchema."""
    if not isinstance(schema, Schema):
        return MappingSchema(schema, **kwargs)
    return schema
def ensure_column_mapping(mapping: Union[dict, str, list, NoneType]) -> dict:
def ensure_column_mapping(mapping: ColumnMapping | None) -> dict:
    """
    Coerce a column mapping in any supported form into a dict.

    Args:
        mapping: None, a dict ({col: type}), a comma-separated string
            ("col1: type1, col2: type2"), or a list of column names
            (whose types default to None).

    Returns:
        A {column_name: column_type} dict.

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        # Split each entry on the FIRST colon only, so types that contain
        # colons (e.g. "m: map<int:str>") survive intact; split(":")[1]
        # would truncate them and raise IndexError on colon-less entries.
        columns = {}
        for entry in mapping.split(","):
            name, _, col_type = entry.strip().partition(":")
            columns[name.strip()] = col_type.strip()
        return columns
    if isinstance(mapping, list):
        return {name.strip(): None for name in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: dict[str, object], depth: int | None = None, keys: list[str] | None = None) -> list[list[str]]:
def flatten_schema(
    schema: dict[str, object], depth: int | None = None, keys: list[str] | None = None
) -> list[list[str]]:
    """
    Flatten a nested schema into a list of key paths, one per table.

    Args:
        schema: the (possibly nested) schema mapping.
        depth: remaining nesting depth; computed from the schema when omitted.
        keys: accumulated path prefix used during recursion.

    Returns:
        A list of key paths, e.g. [["db", "table"], ...].
    """
    if depth is None:
        # The innermost level maps columns to types, so it doesn't count.
        depth = dict_depth(schema) - 1

    prefix = keys or []
    paths: list[list[str]] = []

    for key, value in schema.items():
        if depth == 1 or not isinstance(value, dict):
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
def nested_get( d: dict[str, object], *path: tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: dict[str, object], *path: tuple[str, str], raise_on_missing: bool = True
) -> t.Any | None:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether to raise when a key along the path is absent.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is True.
    """
    node: t.Any = d
    for name, key in path:
        node = node.get(key)
        if node is not None:
            continue
        if not raise_on_missing:
            return None
        # "this" is the internal arg name for tables; report it as "table".
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return node

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set( d: dict[str, typing.Any], keys: Sequence[str], value: Any) -> dict[str, typing.Any]:
def nested_set(d: dict[str, t.Any], keys: Sequence[str], value: t.Any) -> dict[str, t.Any]:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, leaf = keys
    node = d
    for parent in parents:
        # setdefault both creates missing intermediate dicts and descends
        # into existing ones in a single step.
        node = node.setdefault(parent, {})

    node[leaf] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.