Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth, first
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12from sqlglot.helper import trait
 13
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import SchemaArgs
 17    from sqlglot.dialects.dialect import DialectType
 18    from collections.abc import Sequence
 19    from typing_extensions import Unpack
 20
 21    ColumnMapping = t.Union[dict[str, t.Any], str, list[str]]
 22
 23
 24@trait
 25class Schema(abc.ABC):
 26    """Abstract base class for database schemas"""
 27
 28    @property
 29    def dialect(self) -> Dialect | None:
 30        """
 31        Returns None by default. Subclasses that require dialect-specific
 32        behavior should override this property.
 33        """
 34        return None
 35
 36    @abc.abstractmethod
 37    def add_table(
 38        self,
 39        table: exp.Table | str,
 40        column_mapping: ColumnMapping | None = None,
 41        dialect: DialectType = None,
 42        normalize: bool | None = None,
 43        match_depth: bool = True,
 44    ) -> None:
 45        """
 46        Register or update a table. Some implementing classes may require column information to also be provided.
 47        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 48
 49        Args:
 50            table: the `Table` expression instance or string representing the table.
 51            column_mapping: a column mapping that describes the structure of the table.
 52            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 53            normalize: whether to normalize identifiers according to the dialect of interest.
 54            match_depth: whether to enforce that the table must match the schema's depth or not.
 55        """
 56
 57    @abc.abstractmethod
 58    def column_names(
 59        self,
 60        table: exp.Table | str,
 61        only_visible: bool = False,
 62        dialect: DialectType = None,
 63        normalize: bool | None = None,
 64    ) -> Sequence[str]:
 65        """
 66        Get the column names for a table.
 67
 68        Args:
 69            table: the `Table` expression instance.
 70            only_visible: whether to include invisible columns.
 71            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 72            normalize: whether to normalize identifiers according to the dialect of interest.
 73
 74        Returns:
 75            The sequence of column names.
 76        """
 77
 78    @abc.abstractmethod
 79    def get_column_type(
 80        self,
 81        table: exp.Table | str,
 82        column: exp.Column | str,
 83        dialect: DialectType = None,
 84        normalize: bool | None = None,
 85    ) -> exp.DataType:
 86        """
 87        Get the `sqlglot.exp.DataType` type of a column in the schema.
 88
 89        Args:
 90            table: the source table.
 91            column: the target column.
 92            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 93            normalize: whether to normalize identifiers according to the dialect of interest.
 94
 95        Returns:
 96            The resulting column type.
 97        """
 98
 99    def has_column(
100        self,
101        table: exp.Table | str,
102        column: exp.Column | str,
103        dialect: DialectType = None,
104        normalize: bool | None = None,
105    ) -> bool:
106        """
107        Returns whether `column` appears in `table`'s schema.
108
109        Args:
110            table: the source table.
111            column: the target column.
112            dialect: the SQL dialect that will be used to parse `table` if it's a string.
113            normalize: whether to normalize identifiers according to the dialect of interest.
114
115        Returns:
116            True if the column appears in the schema, False otherwise.
117        """
118        name = column if isinstance(column, str) else column.name
119        return name in self.column_names(table, dialect=dialect, normalize=normalize)
120
121    def get_udf_type(
122        self,
123        udf: exp.Anonymous | str,
124        dialect: DialectType = None,
125        normalize: bool | None = None,
126    ) -> exp.DataType:
127        """
128        Get the return type of a UDF.
129
130        Args:
131            udf: the UDF expression or string.
132            dialect: the SQL dialect for parsing string arguments.
133            normalize: whether to normalize identifiers.
134
135        Returns:
136            The return type as a DataType, or UNKNOWN if not found.
137        """
138        return exp.DType.UNKNOWN.into_expr()
139
140    @property
141    @abc.abstractmethod
142    def supported_table_args(self) -> tuple[str, ...]:
143        """
144        Table arguments this schema support, e.g. `("this", "db", "catalog")`
145        """
146
147    @property
148    def empty(self) -> bool:
149        """Returns whether the schema is empty."""
150        return True
151
152
class AbstractMappingSchema:
    """
    Shared machinery for schemas stored as nested dictionaries.

    Besides the raw `mapping` (tables) and `udf_mapping` (UDFs), a trie is built for each of
    them from the reversed key paths, which is what lookups by partial path go through.
    """

    def __init__(
        self,
        mapping: dict[str, object] | None = None,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        self.mapping: dict[str, object] = mapping or {}
        # Paths are reversed so that the innermost name comes first in the trie.
        table_paths = flatten_schema(self.mapping, depth=self.depth())
        self.mapping_trie: dict[str, object] = new_trie(
            tuple(reversed(path)) for path in table_paths
        )

        self.udf_mapping: dict[str, object] = udf_mapping or {}
        udf_paths = flatten_schema(self.udf_mapping, depth=self.udf_depth())
        self.udf_trie: dict[str, object] = new_trie(tuple(reversed(path)) for path in udf_paths)

        # Computed lazily by the `supported_table_args` property.
        self._supported_table_args: tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether no table mapping has been registered."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the table mapping, as reported by `dict_depth`."""
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        """Nesting depth of the UDF mapping, as reported by `dict_depth`."""
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> tuple[str, ...]:
        """
        The leading `exp.TABLE_PARTS` entries that correspond to the mapping's depth.

        The value is computed on first access with a non-empty mapping and then reused.

        Raises:
            SchemaError: if the depth is non-zero and outside the 1..3 range.
        """
        if self._supported_table_args or not self.mapping:
            return self._supported_table_args

        depth = self.depth()
        if depth and not 1 <= depth <= 3:
            raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        self._supported_table_args = exp.TABLE_PARTS[:depth] if depth else tuple()
        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> list[str]:
        """Names of the table's parts, in reverse order of `table.parts`."""
        return [part.name for part in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> list[str]:
        """Reversed name parts of a UDF reference, truncated to the UDF mapping's depth."""
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        if isinstance(parent, exp.Dot):
            names = [part.name for part in parent.flatten()]
        else:
            names = [udf.name]

        names.reverse()
        return names[: self.udf_depth()]

    def _find_in_trie(
        self,
        parts: list[str],
        trie: dict[str, object],
        raise_on_missing: bool,
    ) -> list[str] | None:
        """
        Resolve `parts` against `trie`, completing a partial path when it is unambiguous.

        Note that `parts` is extended in place when a unique completion is found.

        Args:
            parts: the (reversed) path to look up.
            trie: the trie to search.
            raise_on_missing: whether an ambiguous partial path raises instead of yielding None.

        Returns:
            The resolved path, or None when the lookup fails or is ambiguous.

        Raises:
            SchemaError: if the path is an ambiguous prefix and `raise_on_missing` is set.
        """
        outcome, subtrie = in_trie(trie, parts)

        if outcome == TrieResult.FAILED:
            return None

        if outcome != TrieResult.PREFIX:
            return parts

        candidates = flatten_schema(subtrie)
        if len(candidates) == 1:
            parts.extend(candidates[0])
            return parts

        if raise_on_missing:
            joined_parts = ".".join(parts)
            message = ", ".join(".".join(candidate) for candidate in candidates)
            raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

        return None

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """
        Fetch the schema stored for a table.

        Args:
            table: the table to look up.
            raise_on_missing: whether to raise when the schema cannot be found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
                This implementation accepts the flag but does not act on it.

        Returns:
            The schema of the table, or None when it could not be resolved.
        """
        wanted = len(self.supported_table_args)
        parts = self.table_parts(table)[0:wanted]

        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)
        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Any | None:
        """
        Fetch the return type stored for a UDF.

        Args:
            udf: the UDF expression to look up.
            raise_on_missing: whether to raise when the UDF cannot be found.

        Returns:
            The return type of the UDF, or None when it could not be resolved.
        """
        resolved_parts = self._find_in_trie(self.udf_parts(udf), self.udf_trie, raise_on_missing)
        if resolved_parts is None:
            return None

        # Each path entry is (name used in the error message, key), walked outermost first.
        path = zip(resolved_parts, reversed(resolved_parts))
        return nested_get(self.udf_mapping, *path, raise_on_missing=raise_on_missing)

    def nested_get(
        self,
        parts: Sequence[str],
        d: dict[str, object] | None = None,
        raise_on_missing: bool = True,
    ) -> t.Any | None:
        """
        Walk a nested dictionary along the reverse of `parts`.

        Args:
            parts: the path, innermost name first.
            d: the dictionary to walk; the table mapping is used when this is falsy.
            raise_on_missing: whether a missing key raises instead of yielding None.

        Returns:
            The value found at the end of the path, or None when it is missing.
        """
        source = d or self.mapping
        path = zip(self.supported_table_args, reversed(parts))
        return nested_get(source, *path, raise_on_missing=raise_on_missing)
285
286
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
        udf_mapping: Optional nested mapping from UDF name parts to the UDF's return type
            (a `DataType` or a type string).
    """

    def __init__(
        self,
        schema: dict[str, object] | None = None,
        visible: dict[str, object] | None = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        self.visible: dict[str, object] = {} if visible is None else visible
        self.normalize: bool = normalize
        self._dialect: Dialect = Dialect.get_or_raise(dialect)
        # Parsed types are cached per (type string, dialect), since the same string may
        # parse differently depending on the dialect.
        self._type_mapping_cache: dict[tuple[str, DialectType], exp.DataType] = {}
        self._normalized_table_cache: dict[tuple[exp.Table, DialectType, bool], exp.Table] = {}
        self._normalized_name_cache: dict[tuple[str, DialectType, bool, bool], str] = {}
        self._find_cache: dict[tuple[exp.Table, bool], dict[str, object] | None] = {}
        # Lazily computed by `depth()`; must exist before the parent initializer calls it.
        self._depth: int = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )

    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """
        Build a new `MappingSchema` from the mapping, visibility, dialect, normalization flag
        and UDF mapping of an existing one.

        Args:
            mapping_schema: the schema to take the settings from.

        Returns:
            A new `MappingSchema` instance.
        """
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """
        Returns the schema of a given table, caching the result per `(table, ensure_data_types)`.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table, or None if it could not be resolved.
        """
        cache_key = (table, ensure_data_types)
        schema = self._find_cache.get(cache_key)

        # A cached None is indistinguishable from a miss, so failed lookups are retried.
        if schema is None:
            schema = super().find(table, raise_on_missing=raise_on_missing)
            if ensure_data_types and isinstance(schema, dict):
                schema = {
                    col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                    for col, dtype in schema.items()
                }
            self._find_cache[cache_key] = schema

        return schema

    def copy(
        self, schema: dict[str, object] | None = None, **kwargs: Unpack[SchemaArgs]
    ) -> MappingSchema:
        """
        Create a new `MappingSchema` based on this one.

        Args:
            schema: the mapping for the new schema; a shallow copy of this schema's
                mapping is used when omitted.
            **kwargs: constructor arguments that override the ones taken from this schema.

        Returns:
            The new `MappingSchema`.
        """
        mapping_kwargs: SchemaArgs = {
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
            "udf_mapping": self.udf_mapping.copy(),
            **kwargs,
        }
        return MappingSchema(self.mapping.copy() if schema is None else schema, **mapping_kwargs)

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: ColumnMapping | None = None,
        dialect: DialectType = None,
        normalize: bool | None = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.

        Raises:
            SchemaError: if `match_depth` is set, the schema is not empty and the table's
                number of parts differs from the schema's depth.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An already known table is left untouched unless new columns were provided.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)
        # Drop any cached lookup results for this table, for both `ensure_data_types` values.
        self._find_cache.pop((normalized_table, True), None)
        self._find_cache.pop((normalized_table, False), None)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> list[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to restrict the result to the columns listed in `visible`.
                Ignored when no visibility mapping was provided.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names, empty if the table's schema is not found.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema: dict[str, object] | None = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect used to parse `table` and string types.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The column type, or UNKNOWN if the table or column is not found or the stored
            type is neither a `DataType` nor a string.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DType.UNKNOWN.into_expr()

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DType.UNKNOWN.into_expr()

        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DType.UNKNOWN.into_expr()

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the normalized column name is in the table's schema, False otherwise
            (including when the table is not found).
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: dict[str, object]) -> dict[str, object]:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if a table sits at a different nesting level than the first one,
                or if a table has no columns.
        """
        normalized_mapping: dict[str, object] = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # The path ended on a non-dict value: this entry is nested too shallowly.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            # The "columns" are themselves mappings: this entry is nested too deeply.
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_udfs(self, udfs: dict[str, object]) -> dict[str, object]:
        """
        Normalizes all identifiers in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: dict[str, object] = {}

        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> list[str]:
        """
        Extract and normalize UDF parts for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).

        Raises:
            SchemaError: if a string argument does not parse into a (qualified) UDF call.
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expr = exp.maybe_parse(udf, dialect=dialect)

            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                # Qualified call: the function itself is the right-hand side of the Dot.
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")
        parts = self.udf_parts(udf)

        if normalize:
            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

        return parts

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.Table:
        """
        Parse `table` if needed and normalize the identifiers in its parts.

        Args:
            table: the `Table` expression instance or string representing the table.
            dialect: the SQL dialect used for parsing and normalization; defaults to the
                schema's dialect.
            normalize: whether to normalize identifiers; defaults to the schema's setting.

        Returns:
            The (possibly cached) normalized table.
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Reuse a previously normalized table when one is cached for this input.
        # Entries are stored under the *normalized* table (see below), so only inputs
        # matching an already normalized table's key produce a hit.
        if isinstance(table, exp.Table) and (
            cached := self._normalized_table_cache.get((table, dialect, normalize))
        ):
            return cached

        # Copy only when normalizing, since normalization rewrites the parts in place.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        self._normalized_table_cache[(normalized_table, dialect, normalize)] = normalized_table
        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: bool | None = None,
    ) -> str:
        """
        Normalize a single identifier and return its name, caching the result.

        Args:
            name: the identifier, as a string or an `Identifier` expression.
            dialect: the SQL dialect to normalize for; defaults to the schema's dialect.
            is_table: whether the identifier names a table (or table qualifier).
            normalize: whether to normalize; defaults to the schema's setting.

        Returns:
            The name of the resulting identifier.
        """
        normalize = self.normalize if normalize is None else normalize

        dialect = dialect or self.dialect
        name_str = name if isinstance(name, str) else name.name
        cache_key = (name_str, dialect, is_table, normalize)

        if cached := self._normalized_name_cache.get(cache_key):
            return cached

        result = normalize_name(
            name,
            dialect=dialect,
            is_table=is_table,
            normalize=normalize,
        ).name

        self._normalized_name_cache[cache_key] = result
        return result

    def depth(self) -> int:
        """Nesting level of the tables in the mapping, computed once the schema is non-empty."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Results are cached per `(schema_type, dialect)` so that a type parsed under one
        dialect is not handed back for a request made under a different one.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
                Defaults to the schema's dialect.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if building the type fails with an `AttributeError`.
        """
        cache_key = (schema_type, dialect or self.dialect)

        if cache_key not in self._type_mapping_cache:
            resolved_dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = resolved_dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.from_str(schema_type, dialect=resolved_dialect, udt=udt)
                expression.transform(resolved_dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[cache_key] = expression
            except AttributeError as error:
                in_dialect = f" in dialect {resolved_dialect}" if resolved_dialect else ""
                # Chain the original error so the root cause stays visible in tracebacks.
                raise SchemaError(
                    f"Failed to build type '{schema_type}'{in_dialect}."
                ) from error

        return self._type_mapping_cache[cache_key]
689
690
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: bool | None = True,
) -> exp.Identifier:
    """
    Turn a name into an `Identifier`, optionally normalized for a dialect.

    Args:
        identifier: the name, as a string (parsed with `dialect`) or an `Identifier`.
        dialect: the SQL dialect used for parsing and normalization.
        is_table: whether the identifier names a table; recorded in the identifier's `meta`.
        normalize: whether to normalize; when falsy the identifier is returned as is.

    Returns:
        The resulting identifier.
    """
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if normalize:
        # this is used for normalize_identifier, bigquery has special rules pertaining tables
        ident.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(ident)

    return ident
706
707
def ensure_schema(
    schema: Schema | dict[str, object] | None, **kwargs: Unpack[SchemaArgs]
) -> Schema:
    """
    Coerce the argument into a `Schema`.

    Args:
        schema: an existing `Schema`, a nested mapping, or None.
        **kwargs: arguments forwarded to `MappingSchema` when one has to be built.

    Returns:
        `schema` itself if it already is a `Schema`, otherwise a new `MappingSchema`.
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
715
716
def _split_top_level_commas(text: str) -> list[str]:
    """Split `text` on the commas that are not nested inside (), <> or [] brackets."""
    segments: list[str] = []
    depth = 0
    start = 0

    for index, char in enumerate(text):
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            # Never go negative, so a stray closing bracket can't disable splitting.
            depth = max(depth - 1, 0)
        elif char == "," and depth == 0:
            segments.append(text[start:index])
            start = index + 1

    segments.append(text[start:])
    return segments


def ensure_column_mapping(mapping: ColumnMapping | None) -> dict[str, t.Any]:
    """
    Coerce the supported column mapping representations into a dictionary.

    Args:
        mapping: one of
            - None, which yields an empty mapping;
            - a dict, which is returned as is;
            - a string of comma-separated `name: type` entries, e.g.
              `"a: int, b: struct<c: int, d: text>"`. Commas nested inside brackets and
              colons after the first one are treated as part of the type;
            - a list of column names, which are mapped to None.

    Returns:
        A mapping from column name to column type.

    Raises:
        IndexError: if an entry of a string mapping has no `:` separator.
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        column_mapping: dict[str, t.Any] = {}
        for entry in _split_top_level_commas(mapping):
            # Only the first colon separates the name from the type.
            pieces = entry.split(":", 1)
            column_mapping[pieces[0].strip()] = pieces[1].strip()
        return column_mapping
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
732
733
734def flatten_schema(
735    schema: dict[str, object], depth: int | None = None, keys: list[str] | None = None
736) -> list[list[str]]:
737    tables: list[list[str]] = []
738    keys = keys or []
739    depth = dict_depth(schema) - 1 if depth is None else depth
740
741    for k, v in schema.items():
742        if depth == 1 or not isinstance(v, dict):
743            tables.append(keys + [k])
744        elif depth >= 2:
745            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
746
747    return tables
748
749
750def nested_get(
751    d: dict[str, object], *path: tuple[str, str], raise_on_missing: bool = True
752) -> t.Any | None:
753    """
754    Get a value for a nested dictionary.
755
756    Args:
757        d: the dictionary to search.
758        *path: tuples of (name, key), where:
759            `key` is the key in the dictionary to get.
760            `name` is a string to use in the error if `key` isn't found.
761
762    Returns:
763        The value or None if it doesn't exist.
764    """
765    result: t.Any = d
766    for name, key in path:
767        result = result.get(key)
768        if result is None:
769            if raise_on_missing:
770                name = "table" if name == "this" else name
771                raise ValueError(f"Unknown {name}: {key}")
772            return None
773
774    return result
775
776
def nested_set(d: dict[str, t.Any], keys: Sequence[str], value: t.Any) -> dict[str, t.Any]:
    """
    Set a value inside a nested dictionary, in place.

    Intermediate dictionaries are created for keys that don't exist yet.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: the dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to store at the end of the key path.

    Returns:
        The same dictionary, updated unless `keys` is empty.
    """
    if not keys:
        return d

    *parents, leaf = keys

    node = d
    for parent in parents:
        # Descend into the existing level, creating it first when it is missing.
        node = node[parent] if parent in node else node.setdefault(parent, {})

    node[leaf] = value
    return d
@trait
class Schema(abc.ABC):
 25@trait
 26class Schema(abc.ABC):
 27    """Abstract base class for database schemas"""
 28
 29    @property
 30    def dialect(self) -> Dialect | None:
 31        """
 32        Returns None by default. Subclasses that require dialect-specific
 33        behavior should override this property.
 34        """
 35        return None
 36
 37    @abc.abstractmethod
 38    def add_table(
 39        self,
 40        table: exp.Table | str,
 41        column_mapping: ColumnMapping | None = None,
 42        dialect: DialectType = None,
 43        normalize: bool | None = None,
 44        match_depth: bool = True,
 45    ) -> None:
 46        """
 47        Register or update a table. Some implementing classes may require column information to also be provided.
 48        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 49
 50        Args:
 51            table: the `Table` expression instance or string representing the table.
 52            column_mapping: a column mapping that describes the structure of the table.
 53            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 54            normalize: whether to normalize identifiers according to the dialect of interest.
 55            match_depth: whether to enforce that the table must match the schema's depth or not.
 56        """
 57
 58    @abc.abstractmethod
 59    def column_names(
 60        self,
 61        table: exp.Table | str,
 62        only_visible: bool = False,
 63        dialect: DialectType = None,
 64        normalize: bool | None = None,
 65    ) -> Sequence[str]:
 66        """
 67        Get the column names for a table.
 68
 69        Args:
 70            table: the `Table` expression instance.
 71            only_visible: whether to include invisible columns.
 72            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 73            normalize: whether to normalize identifiers according to the dialect of interest.
 74
 75        Returns:
 76            The sequence of column names.
 77        """
 78
 79    @abc.abstractmethod
 80    def get_column_type(
 81        self,
 82        table: exp.Table | str,
 83        column: exp.Column | str,
 84        dialect: DialectType = None,
 85        normalize: bool | None = None,
 86    ) -> exp.DataType:
 87        """
 88        Get the `sqlglot.exp.DataType` type of a column in the schema.
 89
 90        Args:
 91            table: the source table.
 92            column: the target column.
 93            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 94            normalize: whether to normalize identifiers according to the dialect of interest.
 95
 96        Returns:
 97            The resulting column type.
 98        """
 99
100    def has_column(
101        self,
102        table: exp.Table | str,
103        column: exp.Column | str,
104        dialect: DialectType = None,
105        normalize: bool | None = None,
106    ) -> bool:
107        """
108        Returns whether `column` appears in `table`'s schema.
109
110        Args:
111            table: the source table.
112            column: the target column.
113            dialect: the SQL dialect that will be used to parse `table` if it's a string.
114            normalize: whether to normalize identifiers according to the dialect of interest.
115
116        Returns:
117            True if the column appears in the schema, False otherwise.
118        """
119        name = column if isinstance(column, str) else column.name
120        return name in self.column_names(table, dialect=dialect, normalize=normalize)
121
122    def get_udf_type(
123        self,
124        udf: exp.Anonymous | str,
125        dialect: DialectType = None,
126        normalize: bool | None = None,
127    ) -> exp.DataType:
128        """
129        Get the return type of a UDF.
130
131        Args:
132            udf: the UDF expression or string.
133            dialect: the SQL dialect for parsing string arguments.
134            normalize: whether to normalize identifiers.
135
136        Returns:
137            The return type as a DataType, or UNKNOWN if not found.
138        """
139        return exp.DType.UNKNOWN.into_expr()
140
141    @property
142    @abc.abstractmethod
143    def supported_table_args(self) -> tuple[str, ...]:
144        """
145        Table arguments this schema support, e.g. `("this", "db", "catalog")`
146        """
147
148    @property
149    def empty(self) -> bool:
150        """Returns whether the schema is empty."""
151        return True

Abstract base class for database schemas

dialect: sqlglot.dialects.Dialect | None
    @property
    def dialect(self) -> Dialect | None:
        """
        The dialect tied to this schema.

        Returns None by default. Subclasses that require dialect-specific
        behavior should override this property.
        """
        return None

Returns None by default. Subclasses that require dialect-specific behavior should override this property.

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.query.Table | str, column_mapping: Union[dict[str, Any], str, list[str], NoneType] = None, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None, match_depth: bool = True) -> None:
    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: ColumnMapping | None = None,
        dialect: DialectType = None,
        normalize: bool | None = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        This method is abstract: concrete schemas must provide the implementation.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.query.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> Sequence[str]:
    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> Sequence[str]:
        """
        Get the column names for a table.

        This method is abstract: concrete schemas must provide the implementation.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the visible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance or string representing the table.
  • only_visible: whether to return only the visible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.query.Table | str, column: sqlglot.expressions.core.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> sqlglot.expressions.datatypes.DataType:
    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        This method is abstract: concrete schemas must provide the implementation.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.query.Table | str, column: sqlglot.expressions.core.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> bool:
100    def has_column(
101        self,
102        table: exp.Table | str,
103        column: exp.Column | str,
104        dialect: DialectType = None,
105        normalize: bool | None = None,
106    ) -> bool:
107        """
108        Returns whether `column` appears in `table`'s schema.
109
110        Args:
111            table: the source table.
112            column: the target column.
113            dialect: the SQL dialect that will be used to parse `table` if it's a string.
114            normalize: whether to normalize identifiers according to the dialect of interest.
115
116        Returns:
117            True if the column appears in the schema, False otherwise.
118        """
119        name = column if isinstance(column, str) else column.name
120        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def get_udf_type( self, udf: sqlglot.expressions.core.Anonymous | str, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> sqlglot.expressions.datatypes.DataType:
    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: bool | None = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        This base implementation ignores its arguments and always answers UNKNOWN.

        Args:
            udf: the UDF expression or string.
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        return exp.DType.UNKNOWN.into_expr()

Get the return type of a UDF.

Arguments:
  • udf: the UDF expression or string.
  • dialect: the SQL dialect for parsing string arguments.
  • normalize: whether to normalize identifiers.
Returns:

The return type as a DataType, or UNKNOWN if not found.

supported_table_args: tuple[str, ...]
    @property
    @abc.abstractmethod
    def supported_table_args(self) -> tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool
    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty; this base implementation always returns True."""
        return True

Returns whether the schema is empty.

class AbstractMappingSchema:
class AbstractMappingSchema:
    """
    Schema state backed by nested dictionaries.

    Tries built over the flattened key paths of the table and UDF mappings are
    used to resolve names, including names given with fewer qualifiers than
    the mapping has levels.
    """

    def __init__(
        self,
        mapping: dict[str, object] | None = None,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        """
        Args:
            mapping: nested table mapping; a fresh empty dict is used when it is None or empty.
            udf_mapping: nested UDF mapping; a fresh empty dict is used when it is None or empty.
        """
        self.mapping: dict[str, object] = mapping or {}
        # Key paths are stored reversed, matching the reversed part order
        # produced by `table_parts` / `udf_parts`.
        self.mapping_trie: dict[str, object] = new_trie(
            tuple(reversed(keys)) for keys in flatten_schema(self.mapping, depth=self.depth())
        )
        self.udf_mapping: dict[str, object] = udf_mapping or {}
        self.udf_trie: dict[str, object] = new_trie(
            tuple(reversed(keys))
            for keys in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Filled in lazily by the `supported_table_args` property
        self._supported_table_args: tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the table mapping holds no entries."""
        return not self.mapping

    def depth(self) -> int:
        """Return the nesting depth of the table mapping, as computed by `dict_depth`."""
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        """Return the nesting depth of the UDF mapping, as computed by `dict_depth`."""
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> tuple[str, ...]:
        """
        The leading `exp.TABLE_PARTS` entries that match the mapping's depth.

        The value is computed on first access with a non-empty mapping and then
        reused for as long as it is non-empty.

        Raises:
            SchemaError: if the mapping's depth is non-zero and outside 1..3.
        """
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> list[str]:
        """Return the names of `table.parts` in reversed order."""
        return [p.name for p in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> list[str]:
        """
        Return the UDF's name parts in reversed order, cut to at most `udf_depth()` entries.
        """
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
        return list(reversed(parts))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: list[str],
        trie: dict[str, object],
        raise_on_missing: bool,
    ) -> list[str] | None:
        """
        Resolve `parts` against `trie`.

        Args:
            parts: the name parts to look up; extended in place when a prefix
                match has exactly one possible completion.
            trie: the trie to search.
            raise_on_missing: whether an ambiguous prefix match raises.

        Returns:
            The resolved parts, or None if the lookup failed or was ambiguous
            and `raise_on_missing` is False.

        Raises:
            SchemaError: if the prefix match has a number of completions other
                than one and `raise_on_missing` is True.
        """
        value, trie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

                return None

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Any | None:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
                This implementation does not use the flag.

        Returns:
            The schema of the target table, or None if the trie lookup does not resolve it.
        """
        # Only as many parts as the mapping has levels take part in the lookup
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Any | None:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        parts = self.udf_parts(udf)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        # Keys are consumed in reversed part order; the first tuple element only labels errors
        return nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self,
        parts: Sequence[str],
        d: dict[str, object] | None = None,
        raise_on_missing: bool = True,
    ) -> t.Any | None:
        """
        Look `parts` up in a nested dictionary.

        Args:
            parts: the name parts; they are used as keys in reversed order and
                paired with `supported_table_args` for error labelling.
            d: the dictionary to search; the table mapping is used when None.
            raise_on_missing: whether a missing level raises instead of yielding None.

        Returns:
            The value found, or None if it is missing and `raise_on_missing` is False.
        """
        # Only None selects the default: an explicitly passed empty dict must be
        # searched as-is rather than silently replaced by the table mapping.
        return nested_get(
            self.mapping if d is None else d,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
AbstractMappingSchema( mapping: dict[str, object] | None = None, udf_mapping: dict[str, object] | None = None)
    def __init__(
        self,
        mapping: dict[str, object] | None = None,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        """
        Args:
            mapping: nested table mapping; a fresh empty dict is used when it is None or empty.
            udf_mapping: nested UDF mapping; a fresh empty dict is used when it is None or empty.
        """
        self.mapping: dict[str, object] = mapping or {}
        # Key paths are stored reversed in the trie. `self.depth()` reads
        # `self.mapping`, so the mapping has to be assigned first.
        self.mapping_trie: dict[str, object] = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self.udf_mapping: dict[str, object] = udf_mapping or {}
        self.udf_trie: dict[str, object] = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Filled in lazily by the `supported_table_args` property
        self._supported_table_args: tuple[str, ...] = tuple()
mapping: dict[str, object]
mapping_trie: dict[str, object]
udf_mapping: dict[str, object]
udf_trie: dict[str, object]
empty: bool
    @property
    def empty(self) -> bool:
        """Whether the table mapping holds no entries."""
        return not self.mapping
def depth(self) -> int:
    def depth(self) -> int:
        """Return the nesting depth of the table mapping, as computed by `dict_depth`."""
        return dict_depth(self.mapping)
def udf_depth(self) -> int:
    def udf_depth(self) -> int:
        """Return the nesting depth of the UDF mapping, as computed by `dict_depth`."""
        return dict_depth(self.udf_mapping)
supported_table_args: tuple[str, ...]
181    @property
182    def supported_table_args(self) -> tuple[str, ...]:
183        if not self._supported_table_args and self.mapping:
184            depth = self.depth()
185
186            if not depth:  # None
187                self._supported_table_args = tuple()
188            elif 1 <= depth <= 3:
189                self._supported_table_args = exp.TABLE_PARTS[:depth]
190            else:
191                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
192
193        return self._supported_table_args
def table_parts(self, table: sqlglot.expressions.query.Table) -> list[str]:
195    def table_parts(self, table: exp.Table) -> list[str]:
196        return [p.name for p in reversed(table.parts)]
def udf_parts(self, udf: sqlglot.expressions.core.Anonymous) -> list[str]:
198    def udf_parts(self, udf: exp.Anonymous) -> list[str]:
199        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
200        parent = udf.parent
201        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
202        return list(reversed(parts))[0 : self.udf_depth()]
def find( self, table: sqlglot.expressions.query.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
230    def find(
231        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
232    ) -> t.Any | None:
233        """
234        Returns the schema of a given table.
235
236        Args:
237            table: the target table.
238            raise_on_missing: whether to raise in case the schema is not found.
239            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
240
241        Returns:
242            The schema of the target table.
243        """
244        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
245        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)
246
247        if resolved_parts is None:
248            return None
249
250        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
  • ensure_data_types: whether to convert str types to their DataType equivalents.
Returns:

The schema of the target table.

def find_udf( self, udf: sqlglot.expressions.core.Anonymous, raise_on_missing: bool = False) -> Optional[Any]:
252    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Any | None:
253        """
254        Returns the return type of a given UDF.
255
256        Args:
257            udf: the target UDF expression.
258            raise_on_missing: whether to raise if the UDF is not found.
259
260        Returns:
261            The return type of the UDF, or None if not found.
262        """
263        parts = self.udf_parts(udf)
264        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)
265
266        if resolved_parts is None:
267            return None
268
269        return nested_get(
270            self.udf_mapping,
271            *zip(resolved_parts, reversed(resolved_parts)),
272            raise_on_missing=raise_on_missing,
273        )

Returns the return type of a given UDF.

Arguments:
  • udf: the target UDF expression.
  • raise_on_missing: whether to raise if the UDF is not found.
Returns:

The return type of the UDF, or None if not found.

def nested_get( self, parts: Sequence[str], d: dict[str, object] | None = None, raise_on_missing: bool = True) -> Optional[Any]:
275    def nested_get(
276        self,
277        parts: Sequence[str],
278        d: dict[str, object] | None = None,
279        raise_on_missing: bool = True,
280    ) -> t.Any | None:
281        return nested_get(
282            d or self.mapping,
283            *zip(self.supported_table_args, reversed(parts)),
284            raise_on_missing=raise_on_missing,
285        )
class MappingSchema(AbstractMappingSchema, Schema):
288class MappingSchema(AbstractMappingSchema, Schema):
289    """
290    Schema based on a nested mapping.
291
292    Args:
293        schema: Mapping in one of the following forms:
294            1. {table: {col: type}}
295            2. {db: {table: {col: type}}}
296            3. {catalog: {db: {table: {col: type}}}}
297            4. None - Tables will be added later
298        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
299            are assumed to be visible. The nesting should mirror that of the schema:
300            1. {table: set(*cols)}}
301            2. {db: {table: set(*cols)}}}
302            3. {catalog: {db: {table: set(*cols)}}}}
303        dialect: The dialect to be used for custom type mappings & parsing string arguments.
304        normalize: Whether to normalize identifier names according to the given dialect or not.
305    """
306
    def __init__(
        self,
        schema: dict[str, object] | None = None,
        visible: dict[str, object] | None = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: dict[str, object] | None = None,
    ) -> None:
        """
        Set up the dialect and caches, then hand the (optionally normalized)
        table and UDF mappings to the parent initializer.
        """
        self.visible: dict[str, object] = {} if visible is None else visible
        self.normalize: bool = normalize
        self._dialect: Dialect = Dialect.get_or_raise(dialect)
        # Memoization tables. They are created before the normalization calls below;
        # presumably the normalization helpers consult them (their code is not shown here).
        self._type_mapping_cache: dict[str, exp.DataType] = {}
        self._normalized_table_cache: dict[tuple[exp.Table, DialectType, bool], exp.Table] = {}
        self._normalized_name_cache: dict[tuple[str, DialectType, bool, bool], str] = {}
        # Keyed by (table, ensure_data_types); see `find`
        self._find_cache: dict[tuple[exp.Table, bool], dict[str, object] | None] = {}
        self._depth: int = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        # The mappings are normalized only when `normalize` is set
        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )
330
    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema, as stored in `_dialect`."""
        return self._dialect
335
    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """
        Build a new `MappingSchema` from another one's mapping, visibility mapping,
        dialect, normalize flag and UDF mapping.

        The dictionaries are passed to the constructor as-is, without copying them here.
        """
        # NOTE(review): `cls` is unused, so the result is always a plain `MappingSchema`,
        # even when called on a subclass.
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )
345
346    def find(
347        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
348    ) -> t.Any | None:
349        cache_key = (table, ensure_data_types)
350        schema = self._find_cache.get(cache_key)
351
352        if schema is None:
353            schema = super().find(table, raise_on_missing=raise_on_missing)
354            if ensure_data_types and isinstance(schema, dict):
355                schema = {
356                    col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
357                    for col, dtype in schema.items()
358                }
359            self._find_cache[cache_key] = schema
360
361        return schema
362
363    def copy(
364        self, schema: dict[str, object] | None = None, **kwargs: Unpack[SchemaArgs]
365    ) -> MappingSchema:
366        mapping_kwargs: SchemaArgs = {
367            "visible": self.visible.copy(),
368            "dialect": self.dialect,
369            "normalize": self.normalize,
370            "udf_mapping": self.udf_mapping.copy(),
371            **kwargs,
372        }
373        return MappingSchema(self.mapping.copy() if schema is None else schema, **mapping_kwargs)
374
375    def add_table(
376        self,
377        table: exp.Table | str,
378        column_mapping: ColumnMapping | None = None,
379        dialect: DialectType = None,
380        normalize: bool | None = None,
381        match_depth: bool = True,
382    ) -> None:
383        """
384        Register or update a table. Updates are only performed if a new column mapping is provided.
385        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
386
387        Args:
388            table: the `Table` expression instance or string representing the table.
389            column_mapping: a column mapping that describes the structure of the table.
390            dialect: the SQL dialect that will be used to parse `table` if it's a string.
391            normalize: whether to normalize identifiers according to the dialect of interest.
392            match_depth: whether to enforce that the table must match the schema's depth or not.
393        """
394        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
395
396        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
397            raise SchemaError(
398                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
399                f"schema's nesting level: {self.depth()}."
400            )
401
402        normalized_column_mapping = {
403            self._normalize_name(key, dialect=dialect, normalize=normalize): value
404            for key, value in ensure_column_mapping(column_mapping).items()
405        }
406
407        schema = self.find(normalized_table, raise_on_missing=False)
408        if schema and not normalized_column_mapping:
409            return
410
411        parts = self.table_parts(normalized_table)
412
413        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
414        new_trie([parts], self.mapping_trie)
415        self._find_cache.pop((normalized_table, True), None)
416        self._find_cache.pop((normalized_table, False), None)
417
418    def column_names(
419        self,
420        table: exp.Table | str,
421        only_visible: bool = False,
422        dialect: DialectType = None,
423        normalize: bool | None = None,
424    ) -> list[str]:
425        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
426
427        schema: dict[str, object] | None = self.find(normalized_table)
428        if schema is None:
429            return []
430
431        if not only_visible or not self.visible:
432            return list(schema)
433
434        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
435        return [col for col in schema if col in visible]
436
437    def get_column_type(
438        self,
439        table: exp.Table | str,
440        column: exp.Column | str,
441        dialect: DialectType = None,
442        normalize: bool | None = None,
443    ) -> exp.DataType:
444        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
445
446        normalized_column_name = self._normalize_name(
447            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
448        )
449
450        table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
451        if table_schema:
452            column_type = table_schema.get(normalized_column_name)
453
454            if isinstance(column_type, exp.DataType):
455                return column_type
456            elif isinstance(column_type, str):
457                return self._to_data_type(column_type, dialect=dialect)
458
459        return exp.DType.UNKNOWN.into_expr()
460
461    def get_udf_type(
462        self,
463        udf: exp.Anonymous | str,
464        dialect: DialectType = None,
465        normalize: bool | None = None,
466    ) -> exp.DataType:
467        """
468        Get the return type of a UDF.
469
470        Args:
471            udf: the UDF expression or string (e.g., "db.my_func()").
472            dialect: the SQL dialect for parsing string arguments.
473            normalize: whether to normalize identifiers.
474
475        Returns:
476            The return type as a DataType, or UNKNOWN if not found.
477        """
478        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
479        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)
480
481        if resolved_parts is None:
482            return exp.DType.UNKNOWN.into_expr()
483
484        udf_type = nested_get(
485            self.udf_mapping,
486            *zip(resolved_parts, reversed(resolved_parts)),
487            raise_on_missing=False,
488        )
489
490        if isinstance(udf_type, exp.DataType):
491            return udf_type
492        elif isinstance(udf_type, str):
493            return self._to_data_type(udf_type, dialect=dialect)
494
495        return exp.DType.UNKNOWN.into_expr()
496
497    def has_column(
498        self,
499        table: exp.Table | str,
500        column: exp.Column | str,
501        dialect: DialectType = None,
502        normalize: bool | None = None,
503    ) -> bool:
504        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
505
506        normalized_column_name = self._normalize_name(
507            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
508        )
509
510        table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
511        return normalized_column_name in table_schema if table_schema else False
512
513    def _normalize(self, schema: dict[str, object]) -> dict[str, object]:
514        """
515        Normalizes all identifiers in the schema.
516
517        Args:
518            schema: the schema to normalize.
519
520        Returns:
521            The normalized schema mapping.
522        """
523        normalized_mapping: dict[str, object] = {}
524        flattened_schema = flatten_schema(schema)
525        error_msg = "Table {} must match the schema's nesting level: {}."
526
527        for keys in flattened_schema:
528            columns = nested_get(schema, *zip(keys, keys))
529
530            if not isinstance(columns, dict):
531                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
532            if not columns:
533                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
534            if isinstance(first(columns.values()), dict):
535                raise SchemaError(
536                    error_msg.format(
537                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
538                    ),
539                )
540
541            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
542            for column_name, column_type in columns.items():
543                nested_set(
544                    normalized_mapping,
545                    normalized_keys + [self._normalize_name(column_name)],
546                    column_type,
547                )
548
549        return normalized_mapping
550
551    def _normalize_udfs(self, udfs: dict[str, object]) -> dict[str, object]:
552        """
553        Normalizes all identifiers in the UDF mapping.
554
555        Args:
556            udfs: the UDF mapping to normalize.
557
558        Returns:
559            The normalized UDF mapping.
560        """
561        normalized_mapping: dict[str, object] = {}
562
563        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
564            udf_type = nested_get(udfs, *zip(keys, keys))
565            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
566            nested_set(normalized_mapping, normalized_keys, udf_type)
567
568        return normalized_mapping
569
570    def _normalize_udf(
571        self,
572        udf: exp.Anonymous | str,
573        dialect: DialectType = None,
574        normalize: bool | None = None,
575    ) -> list[str]:
576        """
577        Extract and normalize UDF parts for lookup.
578
579        Args:
580            udf: the UDF expression or qualified string (e.g., "db.my_func()").
581            dialect: the SQL dialect for parsing.
582            normalize: whether to normalize identifiers.
583
584        Returns:
585            A list of normalized UDF parts (reversed for trie lookup).
586        """
587        dialect = dialect or self.dialect
588        normalize = self.normalize if normalize is None else normalize
589
590        if isinstance(udf, str):
591            parsed: exp.Expr = exp.maybe_parse(udf, dialect=dialect)
592
593            if isinstance(parsed, exp.Anonymous):
594                udf = parsed
595            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
596                udf = parsed.expression
597            else:
598                raise SchemaError(f"Unable to parse UDF from: {udf!r}")
599        parts = self.udf_parts(udf)
600
601        if normalize:
602            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]
603
604        return parts
605
606    def _normalize_table(
607        self,
608        table: exp.Table | str,
609        dialect: DialectType = None,
610        normalize: bool | None = None,
611    ) -> exp.Table:
612        dialect = dialect or self.dialect
613        normalize = self.normalize if normalize is None else normalize
614
615        # Cache normalized tables by object id for exp.Table inputs
616        # This is effective when the same Table object is looked up multiple times
617        if isinstance(table, exp.Table) and (
618            cached := self._normalized_table_cache.get((table, dialect, normalize))
619        ):
620            return cached
621
622        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)
623
624        if normalize:
625            for part in normalized_table.parts:
626                if isinstance(part, exp.Identifier):
627                    part.replace(
628                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
629                    )
630
631        self._normalized_table_cache[(normalized_table, dialect, normalize)] = normalized_table
632        return normalized_table
633
634    def _normalize_name(
635        self,
636        name: str | exp.Identifier,
637        dialect: DialectType = None,
638        is_table: bool = False,
639        normalize: bool | None = None,
640    ) -> str:
641        normalize = self.normalize if normalize is None else normalize
642
643        dialect = dialect or self.dialect
644        name_str = name if isinstance(name, str) else name.name
645        cache_key = (name_str, dialect, is_table, normalize)
646
647        if cached := self._normalized_name_cache.get(cache_key):
648            return cached
649
650        result = normalize_name(
651            name,
652            dialect=dialect,
653            is_table=is_table,
654            normalize=normalize,
655        ).name
656
657        self._normalized_name_cache[cache_key] = result
658        return result
659
660    def depth(self) -> int:
661        if not self.empty and not self._depth:
662            # The columns themselves are a mapping, but we don't want to include those
663            self._depth = super().depth() - 1
664        return self._depth
665
666    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
667        """
668        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
669
670        Args:
671            schema_type: the type we want to convert.
672            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
673
674        Returns:
675            The resulting expression type.
676        """
677        if schema_type not in self._type_mapping_cache:
678            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
679            udt = dialect.SUPPORTS_USER_DEFINED_TYPES
680
681            try:
682                expression = exp.DataType.from_str(schema_type, dialect=dialect, udt=udt)
683                expression.transform(dialect.normalize_identifier, copy=False)
684                self._type_mapping_cache[schema_type] = expression
685            except AttributeError:
686                in_dialect = f" in dialect {dialect}" if dialect else ""
687                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
688
689        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: dict[str, object] | None = None, visible: dict[str, object] | None = None, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool = True, udf_mapping: dict[str, object] | None = None)
307    def __init__(
308        self,
309        schema: dict[str, object] | None = None,
310        visible: dict[str, object] | None = None,
311        dialect: DialectType = None,
312        normalize: bool = True,
313        udf_mapping: dict[str, object] | None = None,
314    ) -> None:
315        self.visible: dict[str, object] = {} if visible is None else visible
316        self.normalize: bool = normalize
317        self._dialect: Dialect = Dialect.get_or_raise(dialect)
318        self._type_mapping_cache: dict[str, exp.DataType] = {}
319        self._normalized_table_cache: dict[tuple[exp.Table, DialectType, bool], exp.Table] = {}
320        self._normalized_name_cache: dict[tuple[str, DialectType, bool, bool], str] = {}
321        self._find_cache: dict[tuple[exp.Table, bool], dict[str, object] | None] = {}
322        self._depth: int = 0
323        schema = {} if schema is None else schema
324        udf_mapping = {} if udf_mapping is None else udf_mapping
325
326        super().__init__(
327            self._normalize(schema) if self.normalize else schema,
328            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
329        )
visible: dict[str, object]
normalize: bool
dialect: sqlglot.dialects.Dialect
    @property
    def dialect(self) -> Dialect:
        """Returns the `Dialect` instance stored on this mapping schema at construction time."""
        return self._dialect

Returns the dialect for this mapping schema.

@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
336    @classmethod
337    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
338        return MappingSchema(
339            schema=mapping_schema.mapping,
340            visible=mapping_schema.visible,
341            dialect=mapping_schema.dialect,
342            normalize=mapping_schema.normalize,
343            udf_mapping=mapping_schema.udf_mapping,
344        )
def find( self, table: sqlglot.expressions.query.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
346    def find(
347        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
348    ) -> t.Any | None:
349        cache_key = (table, ensure_data_types)
350        schema = self._find_cache.get(cache_key)
351
352        if schema is None:
353            schema = super().find(table, raise_on_missing=raise_on_missing)
354            if ensure_data_types and isinstance(schema, dict):
355                schema = {
356                    col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
357                    for col, dtype in schema.items()
358                }
359            self._find_cache[cache_key] = schema
360
361        return schema

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
  • ensure_data_types: whether to convert str types to their DataType equivalents.
Returns:

The schema of the target table.

def copy( self, schema: dict[str, object] | None = None, **kwargs: typing_extensions.Unpack[sqlglot._typing.SchemaArgs]) -> MappingSchema:
363    def copy(
364        self, schema: dict[str, object] | None = None, **kwargs: Unpack[SchemaArgs]
365    ) -> MappingSchema:
366        mapping_kwargs: SchemaArgs = {
367            "visible": self.visible.copy(),
368            "dialect": self.dialect,
369            "normalize": self.normalize,
370            "udf_mapping": self.udf_mapping.copy(),
371            **kwargs,
372        }
373        return MappingSchema(self.mapping.copy() if schema is None else schema, **mapping_kwargs)
def add_table( self, table: sqlglot.expressions.query.Table | str, column_mapping: Union[dict[str, Any], str, list[str], NoneType] = None, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None, match_depth: bool = True) -> None:
375    def add_table(
376        self,
377        table: exp.Table | str,
378        column_mapping: ColumnMapping | None = None,
379        dialect: DialectType = None,
380        normalize: bool | None = None,
381        match_depth: bool = True,
382    ) -> None:
383        """
384        Register or update a table. Updates are only performed if a new column mapping is provided.
385        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
386
387        Args:
388            table: the `Table` expression instance or string representing the table.
389            column_mapping: a column mapping that describes the structure of the table.
390            dialect: the SQL dialect that will be used to parse `table` if it's a string.
391            normalize: whether to normalize identifiers according to the dialect of interest.
392            match_depth: whether to enforce that the table must match the schema's depth or not.
393        """
394        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
395
396        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
397            raise SchemaError(
398                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
399                f"schema's nesting level: {self.depth()}."
400            )
401
402        normalized_column_mapping = {
403            self._normalize_name(key, dialect=dialect, normalize=normalize): value
404            for key, value in ensure_column_mapping(column_mapping).items()
405        }
406
407        schema = self.find(normalized_table, raise_on_missing=False)
408        if schema and not normalized_column_mapping:
409            return
410
411        parts = self.table_parts(normalized_table)
412
413        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
414        new_trie([parts], self.mapping_trie)
415        self._find_cache.pop((normalized_table, True), None)
416        self._find_cache.pop((normalized_table, False), None)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.query.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> list[str]:
418    def column_names(
419        self,
420        table: exp.Table | str,
421        only_visible: bool = False,
422        dialect: DialectType = None,
423        normalize: bool | None = None,
424    ) -> list[str]:
425        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
426
427        schema: dict[str, object] | None = self.find(normalized_table)
428        if schema is None:
429            return []
430
431        if not only_visible or not self.visible:
432            return list(schema)
433
434        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
435        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to return only the visible columns, leaving out invisible ones.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

def get_column_type( self, table: sqlglot.expressions.query.Table | str, column: sqlglot.expressions.core.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> sqlglot.expressions.datatypes.DataType:
437    def get_column_type(
438        self,
439        table: exp.Table | str,
440        column: exp.Column | str,
441        dialect: DialectType = None,
442        normalize: bool | None = None,
443    ) -> exp.DataType:
444        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
445
446        normalized_column_name = self._normalize_name(
447            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
448        )
449
450        table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
451        if table_schema:
452            column_type = table_schema.get(normalized_column_name)
453
454            if isinstance(column_type, exp.DataType):
455                return column_type
456            elif isinstance(column_type, str):
457                return self._to_data_type(column_type, dialect=dialect)
458
459        return exp.DType.UNKNOWN.into_expr()

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def get_udf_type( self, udf: sqlglot.expressions.core.Anonymous | str, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> sqlglot.expressions.datatypes.DataType:
461    def get_udf_type(
462        self,
463        udf: exp.Anonymous | str,
464        dialect: DialectType = None,
465        normalize: bool | None = None,
466    ) -> exp.DataType:
467        """
468        Get the return type of a UDF.
469
470        Args:
471            udf: the UDF expression or string (e.g., "db.my_func()").
472            dialect: the SQL dialect for parsing string arguments.
473            normalize: whether to normalize identifiers.
474
475        Returns:
476            The return type as a DataType, or UNKNOWN if not found.
477        """
478        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
479        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)
480
481        if resolved_parts is None:
482            return exp.DType.UNKNOWN.into_expr()
483
484        udf_type = nested_get(
485            self.udf_mapping,
486            *zip(resolved_parts, reversed(resolved_parts)),
487            raise_on_missing=False,
488        )
489
490        if isinstance(udf_type, exp.DataType):
491            return udf_type
492        elif isinstance(udf_type, str):
493            return self._to_data_type(udf_type, dialect=dialect)
494
495        return exp.DType.UNKNOWN.into_expr()

Get the return type of a UDF.

Arguments:
  • udf: the UDF expression or string (e.g., "db.my_func()").
  • dialect: the SQL dialect for parsing string arguments.
  • normalize: whether to normalize identifiers.
Returns:

The return type as a DataType, or UNKNOWN if not found.

def has_column( self, table: sqlglot.expressions.query.Table | str, column: sqlglot.expressions.core.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool | None = None) -> bool:
497    def has_column(
498        self,
499        table: exp.Table | str,
500        column: exp.Column | str,
501        dialect: DialectType = None,
502        normalize: bool | None = None,
503    ) -> bool:
504        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
505
506        normalized_column_name = self._normalize_name(
507            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
508        )
509
510        table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
511        return normalized_column_name in table_schema if table_schema else False

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
660    def depth(self) -> int:
661        if not self.empty and not self._depth:
662            # The columns themselves are a mapping, but we don't want to include those
663            self._depth = super().depth() - 1
664        return self._depth
def normalize_name( identifier: str | sqlglot.expressions.core.Identifier, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, is_table: bool = False, normalize: bool | None = True) -> sqlglot.expressions.core.Identifier:
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: bool | None = True,
) -> exp.Identifier:
    """
    Turn `identifier` into an `Identifier` expression and optionally normalize it.

    Args:
        identifier: the identifier, as a string (parsed with `dialect`) or an expression.
        dialect: the SQL dialect used for parsing and normalization.
        is_table: whether the identifier names a table; recorded in the identifier's meta.
        normalize: whether to apply the dialect's identifier normalization.

    Returns:
        The identifier, normalized by the dialect when `normalize` is truthy.
    """
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if normalize:
        # this is used for normalize_identifier, bigquery has special rules pertaining tables
        ident.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(ident)

    return ident
def ensure_schema( schema: Schema | dict[str, object] | None, **kwargs: typing_extensions.Unpack[sqlglot._typing.SchemaArgs]) -> Schema:
def ensure_schema(
    schema: Schema | dict[str, object] | None, **kwargs: Unpack[SchemaArgs]
) -> Schema:
    """Return `schema` unchanged if it is already a `Schema`; otherwise wrap it in a `MappingSchema` built with `kwargs`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
def ensure_column_mapping( mapping: Union[dict[str, Any], str, list[str], NoneType]) -> dict[str, typing.Any]:
def ensure_column_mapping(mapping: ColumnMapping | None) -> dict[str, t.Any]:
    """
    Coerce the supported column-mapping inputs into a `{column: type}` dictionary.

    Args:
        mapping: one of
            - None: yields an empty mapping;
            - dict: returned as-is (same object);
            - str: comma-separated "name: type" entries, e.g. "a: int, b: text";
            - list: column names, each mapped to None.

    Returns:
        The column mapping as a dictionary.

    Raises:
        IndexError: if an entry of a string mapping has no ":" separator.
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        columns: dict[str, t.Any] = {}
        for entry in mapping.split(","):
            # Split on the first colon only, so a type that itself contains a colon
            # (e.g. "struct<x: int>") is kept whole instead of being truncated.
            pieces = entry.split(":", 1)
            columns[pieces[0].strip()] = pieces[1].strip()
        return columns
    if isinstance(mapping, list):
        return {name.strip(): None for name in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: dict[str, object], depth: int | None = None, keys: list[str] | None = None) -> list[list[str]]:
def flatten_schema(
    schema: dict[str, object], depth: int | None = None, keys: list[str] | None = None
) -> list[list[str]]:
    """
    List the key paths found in a nested schema mapping.

    Args:
        schema: the nested mapping to walk.
        depth: how many levels to descend; defaults to `dict_depth(schema) - 1`.
        keys: key prefix accumulated so far (used by the recursion).

    Returns:
        One list of keys per path; a path ends at `depth` levels or at the first
        value that is not a dict, whichever comes first.
    """
    prefix = keys or []
    if depth is None:
        depth = dict_depth(schema) - 1

    paths: list[list[str]] = []
    for name, value in schema.items():
        path = prefix + [name]
        if depth == 1 or not isinstance(value, dict):
            paths.append(path)
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, path))

    return paths
def nested_get( d: dict[str, object], *path: tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: dict[str, object], *path: tuple[str, str], raise_on_missing: bool = True
) -> t.Any | None:
    """
    Walk a nested dictionary along `path` and return the value found there.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether to raise when a key along the path is missing.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing (or maps to None) and `raise_on_missing` is set.
    """
    current: t.Any = d

    for name, key in path:
        current = current.get(key)
        if current is not None:
            continue

        if not raise_on_missing:
            return None

        # A path element labelled "this" is reported as "table" in the error message.
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set( d: dict[str, typing.Any], keys: Sequence[str], value: Any) -> dict[str, typing.Any]:
def nested_set(d: dict[str, t.Any], keys: Sequence[str], value: t.Any) -> dict[str, t.Any]:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, leaf = keys

    node = d
    for key in parents:
        # Reuse the existing level if present, otherwise create it on the way down.
        node = node.setdefault(key, {})

    node[leaf] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.