Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth, first
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12from sqlglot.helper import mypyc_attr, trait
 13
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot.dialects.dialect import DialectType
 17
 18    ColumnMapping = t.Union[t.Dict, str, t.List]
 19
 20
@trait
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @property
    def dialect(self) -> t.Optional[Dialect]:
        """
        Returns None by default. Subclasses that require dialect-specific
        behavior should override this property.
        """
        return None

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        # Concrete default implementation in terms of the abstract column_names;
        # subclasses (e.g. MappingSchema) may override with a faster lookup.
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string.
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        # Default: no UDF information is available, so everything is UNKNOWN.
        return exp.DataType.build("unknown")

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema support, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty. Defaults to True; stateful subclasses override."""
        return True
148
149
@mypyc_attr(allow_interpreted_subclasses=True)
class AbstractMappingSchema:
    """
    Backs a schema with a nested mapping, plus tries that allow partially
    qualified table / UDF names to be resolved efficiently.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        # `mapping` must be assigned before the trie is built, because
        # depth() (possibly overridden by subclasses) reads self.mapping.
        self.mapping = mapping or {}
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )

        self.udf_mapping = udf_mapping or {}
        self.udf_trie = new_trie(
            tuple(reversed(path))
            for path in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Computed lazily by the supported_table_args property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether no tables have been registered."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the table mapping."""
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        """Nesting depth of the UDF mapping."""
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """The table path parts this schema supports, e.g. `("this", "db", "catalog")`."""
        if self.mapping and not self._supported_table_args:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif depth > 3:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
            else:
                self._supported_table_args = exp.TABLE_PARTS[:depth]

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the table's name parts, innermost first (table, db, catalog)."""
        return [part.name for part in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
        """Return the UDF's qualified name parts, innermost first, truncated to the mapping depth."""
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent

        if isinstance(parent, exp.Dot):
            names = [node.name for node in parent.flatten()]
        else:
            names = [udf.name]

        names.reverse()
        return names[: self.udf_depth()]

    def _find_in_trie(
        self,
        parts: t.List[str],
        trie: t.Dict,
        raise_on_missing: bool,
    ) -> t.Optional[t.List[str]]:
        """
        Resolve `parts` against `trie`. A unique prefix match extends `parts`
        in place with the missing qualifiers; an ambiguous prefix either raises
        or returns None, depending on `raise_on_missing`.
        """
        outcome, subtrie = in_trie(trie, parts)

        if outcome == TrieResult.FAILED:
            return None

        if outcome == TrieResult.PREFIX:
            candidates = flatten_schema(subtrie)

            if len(candidates) != 1:
                if raise_on_missing:
                    prefix = ".".join(parts)
                    options = ", ".join(".".join(candidate) for candidate in candidates)
                    raise SchemaError(f"Ambiguous mapping for {prefix}: {options}.")

                return None

            parts.extend(candidates[0])

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[: len(self.supported_table_args)]
        resolved = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved is None:
            return None

        return self.nested_get(resolved, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        resolved = self._find_in_trie(self.udf_parts(udf), self.udf_trie, raise_on_missing)

        if resolved is None:
            return None

        return nested_get(
            self.udf_mapping,
            *zip(resolved, reversed(resolved)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Look `parts` up in `d` (defaults to the table mapping), outermost qualifier first."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
281
282
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
        udf_mapping: Optional mapping of UDF names to their return types, mirroring the
            nesting of `schema` minus the column level.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        self._dialect = Dialect.get_or_raise(dialect)
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._normalized_table_cache: t.Dict[t.Tuple[exp.Table, DialectType, bool], exp.Table] = {}
        self._normalized_name_cache: t.Dict[t.Tuple[str, DialectType, bool, bool], str] = {}
        # Must be initialized before super().__init__, which calls self.depth().
        self._depth = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )

    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new MappingSchema that shares the given schema's state."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """Like the base `find`, but can coerce string column types into `DataType`s."""
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema, optionally overriding constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                "udf_mapping": self.udf_mapping.copy(),
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # A table that's already registered is only updated when new columns are given.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names of `table`, optionally filtered by visibility."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the type of `column` in `table`, or UNKNOWN if it can't be resolved."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DataType.build("unknown")

        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return whether `column` appears in `table`'s schema."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # Each leaf must be a non-empty {column: type} dict; anything else
            # means the mapping's nesting is inconsistent across tables.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: t.Dict = {}

        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Extract and normalize UDF parts for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expression = exp.maybe_parse(udf, dialect=dialect)

            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                # Qualified call: keep the Anonymous node; udf_parts recovers the qualifiers.
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")

        parts = self.udf_parts(udf)

        if normalize:
            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

        return parts

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse (if needed) and normalize `table`, caching results for exp.Table inputs."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Cache normalized tables keyed on the input Table expression.
        # This is effective when the same Table object is looked up multiple times.
        if isinstance(table, exp.Table) and (
            cached := self._normalized_table_cache.get((table, dialect, normalize))
        ):
            return cached

        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        # Only exp.Table inputs are ever looked up above, so entries keyed by a
        # string could never be hit — don't store them.
        if isinstance(table, exp.Table):
            self._normalized_table_cache[(table, dialect, normalize)] = normalized_table

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier name, with memoization."""
        normalize = self.normalize if normalize is None else normalize

        dialect = dialect or self.dialect
        name_str = name if isinstance(name, str) else name.name
        cache_key = (name_str, dialect, is_table, normalize)

        if cached := self._normalized_name_cache.get(cache_key):
            return cached

        result = normalize_name(
            name,
            dialect=dialect,
            is_table=is_table,
            normalize=normalize,
        ).name

        self._normalized_name_cache[cache_key] = result
        return result

    def depth(self) -> int:
        """Nesting depth of the table mapping, excluding the column level (cached)."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                expression.transform(dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
681
682
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string, then normalize it according to `dialect`."""
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return ident

    # normalize_identifier consults this flag; bigquery has special rules pertaining tables
    ident.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(ident)
698
699
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Return `schema` unchanged if it's already a Schema, else wrap it in a MappingSchema."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
705
706
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping given in any supported form into a dict.

    Supported forms: None (empty), a dict (returned as-is), a string like
    "col1: type1, col2: type2", or a list of column names (types set to None).
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        columns: t.Dict = {}
        for entry in mapping.split(","):
            pieces = entry.split(":")
            columns[pieces[0].strip()] = pieces[1].strip()
        return columns
    if isinstance(mapping, list):
        return {column.strip(): None for column in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
722
723
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Flatten a nested schema mapping into a list of key paths.

    Recurses `depth` levels deep (default: one above the mapping's full depth,
    so column names are excluded) and returns one path per leaf.
    """
    prefix = keys or []
    if depth is None:
        depth = dict_depth(schema) - 1

    paths: t.List[t.List[str]] = []
    for key, value in schema.items():
        if not isinstance(value, dict) or depth == 1:
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
738
739
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    node: t.Any = d
    for name, key in path:
        node = node.get(key)
        if node is not None:
            continue
        if not raise_on_missing:
            return None
        # "this" is the internal name for the table part; report it as "table".
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return node
765
766
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # setdefault both creates missing intermediate dicts and descends into
    # existing ones, so no membership check (or single-key special case) is needed.
    subd = d
    for key in keys[:-1]:
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d
@trait
class Schema(abc.ABC):
    """Abstract base class for database schemas."""

    @property
    def dialect(self) -> t.Optional[Dialect]:
        """
        The dialect associated with this schema, if any.

        Defaults to None; subclasses that require dialect-specific behavior
        should override this property.
        """
        return None

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        column_name = column if isinstance(column, str) else column.name
        return column_name in self.column_names(table, dialect=dialect, normalize=normalize)

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        The base implementation has no UDF information, so it always answers UNKNOWN;
        subclasses with UDF mappings override this.

        Args:
            udf: the UDF expression or string.
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        return exp.DataType.build("unknown")

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        return True

Abstract base class for database schemas

dialect: Optional[sqlglot.dialects.Dialect]
26    @property
27    def dialect(self) -> t.Optional[Dialect]:
28        """
29        Returns None by default. Subclasses that require dialect-specific
30        behavior should override this property.
31        """
32        return None

Returns None by default. Subclasses that require dialect-specific behavior should override this property.

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
34    @abc.abstractmethod
35    def add_table(
36        self,
37        table: exp.Table | str,
38        column_mapping: t.Optional[ColumnMapping] = None,
39        dialect: DialectType = None,
40        normalize: t.Optional[bool] = None,
41        match_depth: bool = True,
42    ) -> None:
43        """
44        Register or update a table. Some implementing classes may require column information to also be provided.
45        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
46
47        Args:
48            table: the `Table` expression instance or string representing the table.
49            column_mapping: a column mapping that describes the structure of the table.
50            dialect: the SQL dialect that will be used to parse `table` if it's a string.
51            normalize: whether to normalize identifiers according to the dialect of interest.
52            match_depth: whether to enforce that the table must match the schema's depth or not.
53        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> Sequence[str]:
55    @abc.abstractmethod
56    def column_names(
57        self,
58        table: exp.Table | str,
59        only_visible: bool = False,
60        dialect: DialectType = None,
61        normalize: t.Optional[bool] = None,
62    ) -> t.Sequence[str]:
63        """
64        Get the column names for a table.
65
66        Args:
67            table: the `Table` expression instance.
68            only_visible: whether to include invisible columns.
69            dialect: the SQL dialect that will be used to parse `table` if it's a string.
70            normalize: whether to normalize identifiers according to the dialect of interest.
71
72        Returns:
73            The sequence of column names.
74        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
76    @abc.abstractmethod
77    def get_column_type(
78        self,
79        table: exp.Table | str,
80        column: exp.Column | str,
81        dialect: DialectType = None,
82        normalize: t.Optional[bool] = None,
83    ) -> exp.DataType:
84        """
85        Get the `sqlglot.exp.DataType` type of a column in the schema.
86
87        Args:
88            table: the source table.
89            column: the target column.
90            dialect: the SQL dialect that will be used to parse `table` if it's a string.
91            normalize: whether to normalize identifiers according to the dialect of interest.
92
93        Returns:
94            The resulting column type.
95        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 97    def has_column(
 98        self,
 99        table: exp.Table | str,
100        column: exp.Column | str,
101        dialect: DialectType = None,
102        normalize: t.Optional[bool] = None,
103    ) -> bool:
104        """
105        Returns whether `column` appears in `table`'s schema.
106
107        Args:
108            table: the source table.
109            column: the target column.
110            dialect: the SQL dialect that will be used to parse `table` if it's a string.
111            normalize: whether to normalize identifiers according to the dialect of interest.
112
113        Returns:
114            True if the column appears in the schema, False otherwise.
115        """
116        name = column if isinstance(column, str) else column.name
117        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def get_udf_type( self, udf: sqlglot.expressions.Anonymous | str, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
119    def get_udf_type(
120        self,
121        udf: exp.Anonymous | str,
122        dialect: DialectType = None,
123        normalize: t.Optional[bool] = None,
124    ) -> exp.DataType:
125        """
126        Get the return type of a UDF.
127
128        Args:
129            udf: the UDF expression or string.
130            dialect: the SQL dialect for parsing string arguments.
131            normalize: whether to normalize identifiers.
132
133        Returns:
134            The return type as a DataType, or UNKNOWN if not found.
135        """
136        return exp.DataType.build("unknown")

Get the return type of a UDF.

Arguments:
  • udf: the UDF expression or string.
  • dialect: the SQL dialect for parsing string arguments.
  • normalize: whether to normalize identifiers.
Returns:

The return type as a DataType, or UNKNOWN if not found.

supported_table_args: Tuple[str, ...]
138    @property
139    @abc.abstractmethod
140    def supported_table_args(self) -> t.Tuple[str, ...]:
141        """
142        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
143        """

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool
145    @property
146    def empty(self) -> bool:
147        """Returns whether the schema is empty."""
148        return True

Returns whether the schema is empty.

@mypyc_attr(allow_interpreted_subclasses=True)
class AbstractMappingSchema:
    """
    Maintains nested mappings of tables and UDFs, plus tries over the reversed
    key paths so partially-qualified names can be resolved by suffix.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Keys are reversed so lookups proceed from the least qualified part (table name) outward.
        self.mapping_trie = new_trie(
            tuple(reversed(keys)) for keys in flatten_schema(self.mapping, depth=self.depth())
        )

        self.udf_mapping = udf_mapping or {}
        self.udf_trie = new_trie(
            tuple(reversed(keys)) for keys in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Lazily computed by `supported_table_args`.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the table mapping holds no entries."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the table mapping."""
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        """Nesting depth of the UDF mapping."""
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """Table arguments this schema supports, derived from the mapping's depth."""
        if self.mapping and not self._supported_table_args:
            depth = self.depth()

            if not depth:
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """The table's name parts, least qualified first."""
        return [part.name for part in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
        """The UDF's name parts, least qualified first, truncated to the UDF mapping depth."""
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        if isinstance(parent, exp.Dot):
            names = [node.name for node in parent.flatten()]
        else:
            names = [udf.name]

        names.reverse()
        return names[: self.udf_depth()]

    def _find_in_trie(
        self,
        parts: t.List[str],
        trie: t.Dict,
        raise_on_missing: bool,
    ) -> t.Optional[t.List[str]]:
        """Resolve `parts` in `trie`, extending them when a unique completion exists."""
        match, subtrie = in_trie(trie, parts)

        if match == TrieResult.FAILED:
            return None

        if match == TrieResult.PREFIX:
            # Partially qualified: usable only if exactly one completion remains.
            candidates = flatten_schema(subtrie)

            if len(candidates) != 1:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in candidates)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

                return None

            parts.extend(candidates[0])

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[: len(self.supported_table_args)]
        resolved = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved is None:
            return None

        return self.nested_get(resolved, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        resolved = self._find_in_trie(self.udf_parts(udf), self.udf_trie, raise_on_missing)

        if resolved is None:
            return None

        return nested_get(
            self.udf_mapping,
            *zip(resolved, reversed(resolved)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Look up `parts` (least qualified first) in `d`, defaulting to the table mapping."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
AbstractMappingSchema(mapping: Optional[Dict] = None, udf_mapping: Optional[Dict] = None)
153    def __init__(
154        self,
155        mapping: t.Optional[t.Dict] = None,
156        udf_mapping: t.Optional[t.Dict] = None,
157    ) -> None:
158        self.mapping = mapping or {}
159        self.mapping_trie = new_trie(
160            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
161        )
162
163        self.udf_mapping = udf_mapping or {}
164        self.udf_trie = new_trie(
165            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
166        )
167
168        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
udf_mapping
udf_trie
empty: bool
170    @property
171    def empty(self) -> bool:
172        return not self.mapping
def depth(self) -> int:
174    def depth(self) -> int:
175        return dict_depth(self.mapping)
def udf_depth(self) -> int:
177    def udf_depth(self) -> int:
178        return dict_depth(self.udf_mapping)
supported_table_args: Tuple[str, ...]
180    @property
181    def supported_table_args(self) -> t.Tuple[str, ...]:
182        if not self._supported_table_args and self.mapping:
183            depth = self.depth()
184
185            if not depth:  # None
186                self._supported_table_args = tuple()
187            elif 1 <= depth <= 3:
188                self._supported_table_args = exp.TABLE_PARTS[:depth]
189            else:
190                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
191
192        return self._supported_table_args
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
194    def table_parts(self, table: exp.Table) -> t.List[str]:
195        return [p.name for p in reversed(table.parts)]
def udf_parts(self, udf: sqlglot.expressions.Anonymous) -> List[str]:
197    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
198        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
199        parent = udf.parent
200        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
201        return list(reversed(parts))[0 : self.udf_depth()]
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
229    def find(
230        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
231    ) -> t.Optional[t.Any]:
232        """
233        Returns the schema of a given table.
234
235        Args:
236            table: the target table.
237            raise_on_missing: whether to raise in case the schema is not found.
238            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
239
240        Returns:
241            The schema of the target table.
242        """
243        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
244        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)
245
246        if resolved_parts is None:
247            return None
248
249        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
  • ensure_data_types: whether to convert str types to their DataType equivalents.
Returns:

The schema of the target table.

def find_udf( self, udf: sqlglot.expressions.Anonymous, raise_on_missing: bool = False) -> Optional[Any]:
251    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
252        """
253        Returns the return type of a given UDF.
254
255        Args:
256            udf: the target UDF expression.
257            raise_on_missing: whether to raise if the UDF is not found.
258
259        Returns:
260            The return type of the UDF, or None if not found.
261        """
262        parts = self.udf_parts(udf)
263        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)
264
265        if resolved_parts is None:
266            return None
267
268        return nested_get(
269            self.udf_mapping,
270            *zip(resolved_parts, reversed(resolved_parts)),
271            raise_on_missing=raise_on_missing,
272        )

Returns the return type of a given UDF.

Arguments:
  • udf: the target UDF expression.
  • raise_on_missing: whether to raise if the UDF is not found.
Returns:

The return type of the UDF, or None if not found.

def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
274    def nested_get(
275        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
276    ) -> t.Optional[t.Any]:
277        return nested_get(
278            d or self.mapping,
279            *zip(self.supported_table_args, reversed(parts)),
280            raise_on_missing=raise_on_missing,
281        )
class MappingSchema(AbstractMappingSchema, Schema):
284class MappingSchema(AbstractMappingSchema, Schema):
285    """
286    Schema based on a nested mapping.
287
288    Args:
289        schema: Mapping in one of the following forms:
290            1. {table: {col: type}}
291            2. {db: {table: {col: type}}}
292            3. {catalog: {db: {table: {col: type}}}}
293            4. None - Tables will be added later
294        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
295            are assumed to be visible. The nesting should mirror that of the schema:
296            1. {table: set(*cols)}}
297            2. {db: {table: set(*cols)}}}
298            3. {catalog: {db: {table: set(*cols)}}}}
299        dialect: The dialect to be used for custom type mappings & parsing string arguments.
300        normalize: Whether to normalize identifier names according to the given dialect or not.
301    """
302
303    def __init__(
304        self,
305        schema: t.Optional[t.Dict] = None,
306        visible: t.Optional[t.Dict] = None,
307        dialect: DialectType = None,
308        normalize: bool = True,
309        udf_mapping: t.Optional[t.Dict] = None,
310    ) -> None:
311        self.visible = {} if visible is None else visible
312        self.normalize = normalize
313        self._dialect = Dialect.get_or_raise(dialect)
314        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
315        self._normalized_table_cache: t.Dict[t.Tuple[exp.Table, DialectType, bool], exp.Table] = {}
316        self._normalized_name_cache: t.Dict[t.Tuple[str, DialectType, bool, bool], str] = {}
317        self._depth = 0
318        schema = {} if schema is None else schema
319        udf_mapping = {} if udf_mapping is None else udf_mapping
320
321        super().__init__(
322            self._normalize(schema) if self.normalize else schema,
323            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
324        )
325
326    @property
327    def dialect(self) -> Dialect:
328        """Returns the dialect for this mapping schema."""
329        return self._dialect
330
331    @classmethod
332    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
333        return MappingSchema(
334            schema=mapping_schema.mapping,
335            visible=mapping_schema.visible,
336            dialect=mapping_schema.dialect,
337            normalize=mapping_schema.normalize,
338            udf_mapping=mapping_schema.udf_mapping,
339        )
340
341    def find(
342        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
343    ) -> t.Optional[t.Any]:
344        schema = super().find(
345            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
346        )
347        if ensure_data_types and isinstance(schema, dict):
348            schema = {
349                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
350                for col, dtype in schema.items()
351            }
352
353        return schema
354
355    def copy(self, **kwargs) -> MappingSchema:
356        return MappingSchema(
357            **{  # type: ignore
358                "schema": self.mapping.copy(),
359                "visible": self.visible.copy(),
360                "dialect": self.dialect,
361                "normalize": self.normalize,
362                "udf_mapping": self.udf_mapping.copy(),
363                **kwargs,
364            }
365        )
366
367    def add_table(
368        self,
369        table: exp.Table | str,
370        column_mapping: t.Optional[ColumnMapping] = None,
371        dialect: DialectType = None,
372        normalize: t.Optional[bool] = None,
373        match_depth: bool = True,
374    ) -> None:
375        """
376        Register or update a table. Updates are only performed if a new column mapping is provided.
377        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
378
379        Args:
380            table: the `Table` expression instance or string representing the table.
381            column_mapping: a column mapping that describes the structure of the table.
382            dialect: the SQL dialect that will be used to parse `table` if it's a string.
383            normalize: whether to normalize identifiers according to the dialect of interest.
384            match_depth: whether to enforce that the table must match the schema's depth or not.
385        """
386        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
387
388        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
389            raise SchemaError(
390                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
391                f"schema's nesting level: {self.depth()}."
392            )
393
394        normalized_column_mapping = {
395            self._normalize_name(key, dialect=dialect, normalize=normalize): value
396            for key, value in ensure_column_mapping(column_mapping).items()
397        }
398
399        schema = self.find(normalized_table, raise_on_missing=False)
400        if schema and not normalized_column_mapping:
401            return
402
403        parts = self.table_parts(normalized_table)
404
405        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
406        new_trie([parts], self.mapping_trie)
407
408    def column_names(
409        self,
410        table: exp.Table | str,
411        only_visible: bool = False,
412        dialect: DialectType = None,
413        normalize: t.Optional[bool] = None,
414    ) -> t.List[str]:
415        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
416
417        schema = self.find(normalized_table)
418        if schema is None:
419            return []
420
421        if not only_visible or not self.visible:
422            return list(schema)
423
424        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
425        return [col for col in schema if col in visible]
426
427    def get_column_type(
428        self,
429        table: exp.Table | str,
430        column: exp.Column | str,
431        dialect: DialectType = None,
432        normalize: t.Optional[bool] = None,
433    ) -> exp.DataType:
434        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
435
436        normalized_column_name = self._normalize_name(
437            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
438        )
439
440        table_schema = self.find(normalized_table, raise_on_missing=False)
441        if table_schema:
442            column_type = table_schema.get(normalized_column_name)
443
444            if isinstance(column_type, exp.DataType):
445                return column_type
446            elif isinstance(column_type, str):
447                return self._to_data_type(column_type, dialect=dialect)
448
449        return exp.DataType.build("unknown")
450
451    def get_udf_type(
452        self,
453        udf: exp.Anonymous | str,
454        dialect: DialectType = None,
455        normalize: t.Optional[bool] = None,
456    ) -> exp.DataType:
457        """
458        Get the return type of a UDF.
459
460        Args:
461            udf: the UDF expression or string (e.g., "db.my_func()").
462            dialect: the SQL dialect for parsing string arguments.
463            normalize: whether to normalize identifiers.
464
465        Returns:
466            The return type as a DataType, or UNKNOWN if not found.
467        """
468        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
469        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)
470
471        if resolved_parts is None:
472            return exp.DataType.build("unknown")
473
474        udf_type = nested_get(
475            self.udf_mapping,
476            *zip(resolved_parts, reversed(resolved_parts)),
477            raise_on_missing=False,
478        )
479
480        if isinstance(udf_type, exp.DataType):
481            return udf_type
482        elif isinstance(udf_type, str):
483            return self._to_data_type(udf_type, dialect=dialect)
484
485        return exp.DataType.build("unknown")
486
487    def has_column(
488        self,
489        table: exp.Table | str,
490        column: exp.Column | str,
491        dialect: DialectType = None,
492        normalize: t.Optional[bool] = None,
493    ) -> bool:
494        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
495
496        normalized_column_name = self._normalize_name(
497            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
498        )
499
500        table_schema = self.find(normalized_table, raise_on_missing=False)
501        return normalized_column_name in table_schema if table_schema else False
502
    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # Each entry in flattened_schema is the key path to one table's column mapping.
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # Too shallow: this path ends before reaching a column mapping.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            # Too deep: the "column" values are themselves nested mappings.
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            # Table-path parts use table normalization rules; column names do not.
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping
540
541    def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
542        """
543        Normalizes all identifiers in the UDF mapping.
544
545        Args:
546            udfs: the UDF mapping to normalize.
547
548        Returns:
549            The normalized UDF mapping.
550        """
551        normalized_mapping: t.Dict = {}
552
553        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
554            udf_type = nested_get(udfs, *zip(keys, keys))
555            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
556            nested_set(normalized_mapping, normalized_keys, udf_type)
557
558        return normalized_mapping
559
560    def _normalize_udf(
561        self,
562        udf: exp.Anonymous | str,
563        dialect: DialectType = None,
564        normalize: t.Optional[bool] = None,
565    ) -> t.List[str]:
566        """
567        Extract and normalize UDF parts for lookup.
568
569        Args:
570            udf: the UDF expression or qualified string (e.g., "db.my_func()").
571            dialect: the SQL dialect for parsing.
572            normalize: whether to normalize identifiers.
573
574        Returns:
575            A list of normalized UDF parts (reversed for trie lookup).
576        """
577        dialect = dialect or self.dialect
578        normalize = self.normalize if normalize is None else normalize
579
580        if isinstance(udf, str):
581            parsed: exp.Expression = exp.maybe_parse(udf, dialect=dialect)
582
583            if isinstance(parsed, exp.Anonymous):
584                udf = parsed
585            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
586                udf = parsed.expression
587            else:
588                raise SchemaError(f"Unable to parse UDF from: {udf!r}")
589        parts = self.udf_parts(udf)
590
591        if normalize:
592            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]
593
594        return parts
595
596    def _normalize_table(
597        self,
598        table: exp.Table | str,
599        dialect: DialectType = None,
600        normalize: t.Optional[bool] = None,
601    ) -> exp.Table:
602        dialect = dialect or self.dialect
603        normalize = self.normalize if normalize is None else normalize
604
605        # Cache normalized tables by object id for exp.Table inputs
606        # This is effective when the same Table object is looked up multiple times
607        if isinstance(table, exp.Table) and (
608            cached := self._normalized_table_cache.get((table, dialect, normalize))
609        ):
610            return cached
611
612        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)
613
614        if normalize:
615            for part in normalized_table.parts:
616                if isinstance(part, exp.Identifier):
617                    part.replace(
618                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
619                    )
620
621        self._normalized_table_cache[(t.cast(exp.Table, table), dialect, normalize)] = (
622            normalized_table
623        )
624        return normalized_table
625
626    def _normalize_name(
627        self,
628        name: str | exp.Identifier,
629        dialect: DialectType = None,
630        is_table: bool = False,
631        normalize: t.Optional[bool] = None,
632    ) -> str:
633        normalize = self.normalize if normalize is None else normalize
634
635        dialect = dialect or self.dialect
636        name_str = name if isinstance(name, str) else name.name
637        cache_key = (name_str, dialect, is_table, normalize)
638
639        if cached := self._normalized_name_cache.get(cache_key):
640            return cached
641
642        result = normalize_name(
643            name,
644            dialect=dialect,
645            is_table=is_table,
646            normalize=normalize,
647        ).name
648
649        self._normalized_name_cache[cache_key] = result
650        return result
651
    def depth(self) -> int:
        # Nesting depth above the column mapping (e.g. catalog.db.table = 3).
        # Cached in _depth (0 means "not yet computed").
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth
657
    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if building the type fails with an AttributeError.
        """
        # NOTE(review): the cache is keyed on the type string alone, so a type
        # first built under one dialect is reused for all dialects — confirm
        # that's intended.
        if schema_type not in self._type_mapping_cache:
            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                # Normalize identifiers nested inside the type (e.g. struct field names).
                expression.transform(dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                # NOTE(review): only AttributeError is converted to SchemaError
                # here; parse failures from DataType.build may propagate as-is.
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(cols)}
    2. {db: {table: set(cols)}}
    3. {catalog: {db: {table: set(cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool = True, udf_mapping: Optional[Dict] = None)
    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        """
        Args:
            schema: nested mapping such as {catalog: {db: {table: {col: type}}}}
                (fewer levels are allowed); None means tables are added later.
            visible: optional mapping, mirroring `schema`'s nesting, of which
                columns are visible.
            dialect: dialect used for custom type mappings and for parsing
                string arguments.
            normalize: whether to normalize identifier names per `dialect`.
            udf_mapping: optional nested mapping from UDF path to return type.
        """
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        self._dialect = Dialect.get_or_raise(dialect)
        # Caches that avoid re-parsing/re-normalizing identical inputs.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._normalized_table_cache: t.Dict[t.Tuple[exp.Table, DialectType, bool], exp.Table] = {}
        self._normalized_name_cache: t.Dict[t.Tuple[str, DialectType, bool, bool], str] = {}
        # 0 means "not computed yet"; see depth().
        self._depth = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )
visible
normalize
dialect: sqlglot.dialects.Dialect
    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        # Always a concrete Dialect: __init__ resolves it via Dialect.get_or_raise.
        return self._dialect

Returns the dialect for this mapping schema.

@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
331    @classmethod
332    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
333        return MappingSchema(
334            schema=mapping_schema.mapping,
335            visible=mapping_schema.visible,
336            dialect=mapping_schema.dialect,
337            normalize=mapping_schema.normalize,
338            udf_mapping=mapping_schema.udf_mapping,
339        )
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
341    def find(
342        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
343    ) -> t.Optional[t.Any]:
344        schema = super().find(
345            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
346        )
347        if ensure_data_types and isinstance(schema, dict):
348            schema = {
349                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
350                for col, dtype in schema.items()
351            }
352
353        return schema

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
  • ensure_data_types: whether to convert str types to their DataType equivalents.
Returns:

The schema of the target table.

def copy(self, **kwargs) -> MappingSchema:
355    def copy(self, **kwargs) -> MappingSchema:
356        return MappingSchema(
357            **{  # type: ignore
358                "schema": self.mapping.copy(),
359                "visible": self.visible.copy(),
360                "dialect": self.dialect,
361                "normalize": self.normalize,
362                "udf_mapping": self.udf_mapping.copy(),
363                **kwargs,
364            }
365        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        # An empty schema accepts any depth: the first table added defines it.
        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Known table and no new columns: nothing to update.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The nested mapping is keyed innermost-part-first, while the trie
        # stores the parts in their original order.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
408    def column_names(
409        self,
410        table: exp.Table | str,
411        only_visible: bool = False,
412        dialect: DialectType = None,
413        normalize: t.Optional[bool] = None,
414    ) -> t.List[str]:
415        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
416
417        schema = self.find(normalized_table)
418        if schema is None:
419            return []
420
421        if not only_visible or not self.visible:
422            return list(schema)
423
424        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
425        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
427    def get_column_type(
428        self,
429        table: exp.Table | str,
430        column: exp.Column | str,
431        dialect: DialectType = None,
432        normalize: t.Optional[bool] = None,
433    ) -> exp.DataType:
434        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
435
436        normalized_column_name = self._normalize_name(
437            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
438        )
439
440        table_schema = self.find(normalized_table, raise_on_missing=False)
441        if table_schema:
442            column_type = table_schema.get(normalized_column_name)
443
444            if isinstance(column_type, exp.DataType):
445                return column_type
446            elif isinstance(column_type, str):
447                return self._to_data_type(column_type, dialect=dialect)
448
449        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def get_udf_type( self, udf: sqlglot.expressions.Anonymous | str, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        # Trie lookup resolves a (possibly partially qualified) UDF path.
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DataType.build("unknown")

        # NOTE(review): the (name, key) pairs walk the mapping with keys
        # reversed relative to the display names — presumably because the UDF
        # mapping is nested innermost-first like the table mapping; confirm
        # against how udf_mapping is stored by the base class.
        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            # String types are parsed lazily and cached by _to_data_type.
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DataType.build("unknown")

Get the return type of a UDF.

Arguments:
  • udf: the UDF expression or string (e.g., "db.my_func()").
  • dialect: the SQL dialect for parsing string arguments.
  • normalize: whether to normalize identifiers.
Returns:

The return type as a DataType, or UNKNOWN if not found.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
487    def has_column(
488        self,
489        table: exp.Table | str,
490        column: exp.Column | str,
491        dialect: DialectType = None,
492        normalize: t.Optional[bool] = None,
493    ) -> bool:
494        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
495
496        normalized_column_name = self._normalize_name(
497            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
498        )
499
500        table_schema = self.find(normalized_table, raise_on_missing=False)
501        return normalized_column_name in table_schema if table_schema else False

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
    def depth(self) -> int:
        # Nesting depth above the column mapping (e.g. catalog.db.table = 3).
        # Cached in _depth (0 means "not yet computed").
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth
def normalize_name( identifier: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> sqlglot.expressions.Identifier:
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and normalize it per `dialect`'s rules."""
    parsed = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return parsed

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    parsed.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(parsed)
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Return `schema` unchanged if it's already a Schema, else wrap it in a MappingSchema."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
def ensure_column_mapping(mapping: Union[Dict, str, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping into dict form.

    Args:
        mapping: one of
            - None: treated as an empty mapping.
            - dict: returned as-is.
            - str: comma-separated "name: type" pairs, e.g. "a: int, b: text".
            - list: column names only; each maps to None.

    Returns:
        A dict of column name to type (or None when no type is known).

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_mapping = {}
        for name_type_str in mapping.split(","):
            # Split on the first colon only, so types that themselves contain
            # colons (e.g. "s: struct<a:int>") are not truncated, and a missing
            # colon yields an empty type instead of an IndexError.
            name, _, column_type = name_type_str.partition(":")
            column_mapping[name.strip()] = column_type.strip()
        return column_mapping
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: Optional[int] = None, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Return the list of key paths that lead to each table's column mapping."""
    if depth is None:
        depth = dict_depth(schema) - 1

    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for key, value in schema.items():
        # At the last level (or when the value isn't nested), this path is complete.
        if depth == 1 or not isinstance(value, dict):
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    node: t.Any = d
    for name, key in path:
        node = node.get(key)
        if node is not None:
            continue
        if not raise_on_missing:
            return None
        name = "table" if name == "this" else name
        raise ValueError(f"Unknown {name}: {key}")

    return node

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    node = d
    for key in keys[:-1]:
        # setdefault both creates missing intermediate levels and descends
        # into existing ones.
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that makeup the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.