Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth, first
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dialects.dialect import DialectType
 14
 15    ColumnMapping = t.Union[t.Dict, str, t.List]
 16
 17
class Schema(abc.ABC):
    """Abstract base class for database schemas."""

    @property
    def dialect(self) -> t.Optional[Dialect]:
        """
        Returns None by default. Subclasses that require dialect-specific
        behavior should override this property.
        """
        return None

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Default implementation always returns UNKNOWN; schemas that track UDFs
        (e.g. `MappingSchema`) override this.

        Args:
            udf: the UDF expression or string.
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        return exp.DataType.build("unknown")

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        return True
144
145
class AbstractMappingSchema:
    """
    Base machinery for schemas backed by nested dict mappings.

    Stores a table mapping and a UDF mapping, plus tries built over their
    (reversed) key paths so lookups can resolve partially-qualified names.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Key paths are reversed (table, db, catalog) so that trie lookups match
        # the most-specific part first, allowing partial qualification.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )

        self.udf_mapping = udf_mapping or {}
        self.udf_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Lazily computed by the `supported_table_args` property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether the table mapping is empty."""
        return not self.mapping

    def depth(self) -> int:
        # Nesting depth of the table mapping (includes the column level).
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        # Nesting depth of the UDF mapping.
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """The table parts this schema supports, e.g. `("this", "db", "catalog")`."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        # Most-specific part first, i.e. [table, db, catalog].
        return [p.name for p in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
        return list(reversed(parts))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: t.List[str],
        trie: t.Dict,
        raise_on_missing: bool,
    ) -> t.Optional[t.List[str]]:
        """
        Resolves `parts` against `trie`, extending a unique prefix match to its
        full key path (mutates `parts` in place). Returns None on failure or
        ambiguity, unless `raise_on_missing` turns ambiguity into a SchemaError.
        """
        value, trie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

                return None

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
                Unused here; honored by subclasses (see `MappingSchema.find`).

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        parts = self.udf_parts(udf)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        # zip(...) yields the (name, key) pairs nested_get expects; keys are
        # re-reversed back into outermost-first mapping order.
        return nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        # `parts` arrives most-specific first; reversed() restores mapping order,
        # while supported_table_args provides the labels used in error messages.
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
276
277
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
        udf_mapping: Optional nested mapping from UDF name (optionally qualified) to its return type.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        self._dialect = Dialect.get_or_raise(dialect)
        # Caches string type -> DataType conversions (see _to_data_type).
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Cached by `depth`; 0 means "not computed yet".
        self._depth = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        # Identifiers are normalized up-front so later lookups compare raw names.
        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )

    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Builds a new `MappingSchema` from an existing instance's attributes."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of `table`, optionally coercing `str` column types
        to `DataType` instances when `ensure_data_types` is set.
        """
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """Returns a copy of this schema; `kwargs` override the constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                "udf_mapping": self.udf_mapping.copy(),
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Known table + no new column info: nothing to update.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Extend the trie in place so subsequent lookups can resolve the new table.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Returns the column names of `table`, optionally filtered to visible ones."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Returns the type of `column` in `table`, or UNKNOWN if it can't be resolved."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DataType.build("unknown")

        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Returns whether `column` appears in `table`'s schema (names normalized first)."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            # zip(keys, keys) produces the (name, key) pairs nested_get expects.
            columns = nested_get(schema, *zip(keys, keys))

            # Validate that every table sits at the same nesting level and has
            # a flat {column: type} mapping under it.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: t.Dict = {}

        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Extract and normalize UDF parts for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expression = exp.maybe_parse(udf, dialect=dialect)

            # Qualified calls parse into Dot(..., Anonymous(...)); unqualified
            # ones parse directly into Anonymous.
            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")
        parts = self.udf_parts(udf)

        if normalize:
            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

        return parts

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parses `table` if needed and normalizes its identifier parts in place."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # copy=normalize: only copy the expression when we're about to mutate
        # its identifier parts below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalizes a single identifier and returns its plain string name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                expression.transform(dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[schema_type] = expression
            # NOTE(review): only AttributeError is caught here, so parse errors
            # from DataType.build propagate as-is — confirm this is intended.
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
652
653
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and normalize it per `dialect`'s casing rules."""
    parsed = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if normalize:
        # normalize_identifier consults this flag; e.g. BigQuery has special
        # rules for table identifiers vs. other identifiers.
        parsed.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(parsed)

    return parsed
669
670
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a `Schema`, wrapping dicts (or None) in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
676
677
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping into a dict.

    Accepts None (empty mapping), a dict (returned as-is), a string of the form
    "col: type, ..." (parsed into {col: type}), or a list of column names
    (mapped to None types). Anything else raises ValueError.
    """
    if mapping is None:
        return {}

    if isinstance(mapping, dict):
        return mapping

    if isinstance(mapping, str):
        # "a: int, b: text" -> {"a": "int", "b": "text"}
        columns: t.Dict = {}
        for pair in mapping.split(","):
            column_name, column_type = pair.split(":")[0], pair.split(":")[1]
            columns[column_name.strip()] = column_type.strip()
        return columns

    if isinstance(mapping, list):
        # Bare column names; their types are unknown.
        return {column.strip(): None for column in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
693
694
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect the nested key paths of `schema` down to `depth` levels.

    Leaf (non-dict) values and values at the target depth terminate a path;
    `keys` carries the prefix accumulated during recursion.
    """
    if depth is None:
        depth = dict_depth(schema) - 1

    prefix = keys or []
    paths = []

    for key, value in schema.items():
        if depth == 1 or not isinstance(value, dict):
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
709
710
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the expression arg holding the table name; report it
            # as "table" for readability.
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
735
736
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        # setdefault inserts a fresh dict for missing keys and returns the
        # existing value otherwise, so no separate membership check is needed.
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d
class Schema(abc.ABC):
 19class Schema(abc.ABC):
 20    """Abstract base class for database schemas"""
 21
 22    @property
 23    def dialect(self) -> t.Optional[Dialect]:
 24        """
 25        Returns None by default. Subclasses that require dialect-specific
 26        behavior should override this property.
 27        """
 28        return None
 29
 30    @abc.abstractmethod
 31    def add_table(
 32        self,
 33        table: exp.Table | str,
 34        column_mapping: t.Optional[ColumnMapping] = None,
 35        dialect: DialectType = None,
 36        normalize: t.Optional[bool] = None,
 37        match_depth: bool = True,
 38    ) -> None:
 39        """
 40        Register or update a table. Some implementing classes may require column information to also be provided.
 41        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 42
 43        Args:
 44            table: the `Table` expression instance or string representing the table.
 45            column_mapping: a column mapping that describes the structure of the table.
 46            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 47            normalize: whether to normalize identifiers according to the dialect of interest.
 48            match_depth: whether to enforce that the table must match the schema's depth or not.
 49        """
 50
 51    @abc.abstractmethod
 52    def column_names(
 53        self,
 54        table: exp.Table | str,
 55        only_visible: bool = False,
 56        dialect: DialectType = None,
 57        normalize: t.Optional[bool] = None,
 58    ) -> t.Sequence[str]:
 59        """
 60        Get the column names for a table.
 61
 62        Args:
 63            table: the `Table` expression instance.
 64            only_visible: whether to include invisible columns.
 65            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 66            normalize: whether to normalize identifiers according to the dialect of interest.
 67
 68        Returns:
 69            The sequence of column names.
 70        """
 71
 72    @abc.abstractmethod
 73    def get_column_type(
 74        self,
 75        table: exp.Table | str,
 76        column: exp.Column | str,
 77        dialect: DialectType = None,
 78        normalize: t.Optional[bool] = None,
 79    ) -> exp.DataType:
 80        """
 81        Get the `sqlglot.exp.DataType` type of a column in the schema.
 82
 83        Args:
 84            table: the source table.
 85            column: the target column.
 86            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 87            normalize: whether to normalize identifiers according to the dialect of interest.
 88
 89        Returns:
 90            The resulting column type.
 91        """
 92
 93    def has_column(
 94        self,
 95        table: exp.Table | str,
 96        column: exp.Column | str,
 97        dialect: DialectType = None,
 98        normalize: t.Optional[bool] = None,
 99    ) -> bool:
100        """
101        Returns whether `column` appears in `table`'s schema.
102
103        Args:
104            table: the source table.
105            column: the target column.
106            dialect: the SQL dialect that will be used to parse `table` if it's a string.
107            normalize: whether to normalize identifiers according to the dialect of interest.
108
109        Returns:
110            True if the column appears in the schema, False otherwise.
111        """
112        name = column if isinstance(column, str) else column.name
113        return name in self.column_names(table, dialect=dialect, normalize=normalize)
114
115    def get_udf_type(
116        self,
117        udf: exp.Anonymous | str,
118        dialect: DialectType = None,
119        normalize: t.Optional[bool] = None,
120    ) -> exp.DataType:
121        """
122        Get the return type of a UDF.
123
124        Args:
125            udf: the UDF expression or string.
126            dialect: the SQL dialect for parsing string arguments.
127            normalize: whether to normalize identifiers.
128
129        Returns:
130            The return type as a DataType, or UNKNOWN if not found.
131        """
132        return exp.DataType.build("unknown")
133
134    @property
135    @abc.abstractmethod
136    def supported_table_args(self) -> t.Tuple[str, ...]:
137        """
138        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
139        """
140
141    @property
142    def empty(self) -> bool:
143        """Returns whether the schema is empty."""
144        return True

Abstract base class for database schemas

dialect: Optional[sqlglot.dialects.Dialect]
22    @property
23    def dialect(self) -> t.Optional[Dialect]:
24        """
25        Returns None by default. Subclasses that require dialect-specific
26        behavior should override this property.
27        """
28        return None

Returns None by default. Subclasses that require dialect-specific behavior should override this property.

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
30    @abc.abstractmethod
31    def add_table(
32        self,
33        table: exp.Table | str,
34        column_mapping: t.Optional[ColumnMapping] = None,
35        dialect: DialectType = None,
36        normalize: t.Optional[bool] = None,
37        match_depth: bool = True,
38    ) -> None:
39        """
40        Register or update a table. Some implementing classes may require column information to also be provided.
41        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
42
43        Args:
44            table: the `Table` expression instance or string representing the table.
45            column_mapping: a column mapping that describes the structure of the table.
46            dialect: the SQL dialect that will be used to parse `table` if it's a string.
47            normalize: whether to normalize identifiers according to the dialect of interest.
48            match_depth: whether to enforce that the table must match the schema's depth or not.
49        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> Sequence[str]:
    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to restrict the result to columns marked as visible,
                excluding invisible ones.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type. Implementations fall back to the UNKNOWN
            type when the column cannot be resolved (see `MappingSchema.get_column_type`).
        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 93    def has_column(
 94        self,
 95        table: exp.Table | str,
 96        column: exp.Column | str,
 97        dialect: DialectType = None,
 98        normalize: t.Optional[bool] = None,
 99    ) -> bool:
100        """
101        Returns whether `column` appears in `table`'s schema.
102
103        Args:
104            table: the source table.
105            column: the target column.
106            dialect: the SQL dialect that will be used to parse `table` if it's a string.
107            normalize: whether to normalize identifiers according to the dialect of interest.
108
109        Returns:
110            True if the column appears in the schema, False otherwise.
111        """
112        name = column if isinstance(column, str) else column.name
113        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def get_udf_type( self, udf: sqlglot.expressions.Anonymous | str, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
115    def get_udf_type(
116        self,
117        udf: exp.Anonymous | str,
118        dialect: DialectType = None,
119        normalize: t.Optional[bool] = None,
120    ) -> exp.DataType:
121        """
122        Get the return type of a UDF.
123
124        Args:
125            udf: the UDF expression or string.
126            dialect: the SQL dialect for parsing string arguments.
127            normalize: whether to normalize identifiers.
128
129        Returns:
130            The return type as a DataType, or UNKNOWN if not found.
131        """
132        return exp.DataType.build("unknown")

Get the return type of a UDF.

Arguments:
  • udf: the UDF expression or string.
  • dialect: the SQL dialect for parsing string arguments.
  • normalize: whether to normalize identifiers.
Returns:

The return type as a DataType, or UNKNOWN if not found.

supported_table_args: Tuple[str, ...]
134    @property
135    @abc.abstractmethod
136    def supported_table_args(self) -> t.Tuple[str, ...]:
137        """
138        Table arguments this schema support, e.g. `("this", "db", "catalog")`
139        """

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool
141    @property
142    def empty(self) -> bool:
143        """Returns whether the schema is empty."""
144        return True

Returns whether the schema is empty.

class AbstractMappingSchema:
class AbstractMappingSchema:
    """Stores table and UDF metadata as nested mappings.

    Lookups go through tries built from the reversed key paths, which makes it
    possible to resolve partially-qualified names when they are unambiguous.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Keys are reversed so lookups proceed from the least to the most
        # qualified part (table -> db -> catalog).
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )

        self.udf_mapping = udf_mapping or {}
        self.udf_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
        )

        # Computed lazily by the `supported_table_args` property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        # Only the table mapping determines emptiness; UDFs alone don't count.
        return not self.mapping

    def depth(self) -> int:
        return dict_depth(self.mapping)

    def udf_depth(self) -> int:
        return dict_depth(self.udf_mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        # Reversed so the least qualified part (the table name) comes first,
        # matching the trie key order built in __init__.
        return [p.name for p in reversed(table.parts)]

    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
        parent = udf.parent
        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
        return list(reversed(parts))[0 : self.udf_depth()]

    def _find_in_trie(
        self,
        parts: t.List[str],
        trie: t.Dict,
        raise_on_missing: bool,
    ) -> t.Optional[t.List[str]]:
        """Resolve `parts` against `trie`, completing a unique partial match.

        Returns the (possibly extended) parts on success, or None when the
        lookup fails or is ambiguous and `raise_on_missing` is False.
        """
        value, trie = in_trie(trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            possibilities = flatten_schema(trie)

            # A prefix match is only usable when exactly one completion exists;
            # otherwise the partially-qualified name is ambiguous.
            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                if raise_on_missing:
                    joined_parts = ".".join(parts)
                    message = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

                return None

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
                Unused here; honored by subclasses (see `MappingSchema.find`).

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
        """
        Returns the return type of a given UDF.

        Args:
            udf: the target UDF expression.
            raise_on_missing: whether to raise if the UDF is not found.

        Returns:
            The return type of the UDF, or None if not found.
        """
        parts = self.udf_parts(udf)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

        if resolved_parts is None:
            return None

        return nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=raise_on_missing,
        )

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Fetch the value stored under `parts` in `d` (defaults to `self.mapping`)."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
AbstractMappingSchema(mapping: Optional[Dict] = None, udf_mapping: Optional[Dict] = None)
148    def __init__(
149        self,
150        mapping: t.Optional[t.Dict] = None,
151        udf_mapping: t.Optional[t.Dict] = None,
152    ) -> None:
153        self.mapping = mapping or {}
154        self.mapping_trie = new_trie(
155            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
156        )
157
158        self.udf_mapping = udf_mapping or {}
159        self.udf_trie = new_trie(
160            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
161        )
162
163        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
udf_mapping
udf_trie
empty: bool
165    @property
166    def empty(self) -> bool:
167        return not self.mapping
def depth(self) -> int:
169    def depth(self) -> int:
170        return dict_depth(self.mapping)
def udf_depth(self) -> int:
172    def udf_depth(self) -> int:
173        return dict_depth(self.udf_mapping)
supported_table_args: Tuple[str, ...]
175    @property
176    def supported_table_args(self) -> t.Tuple[str, ...]:
177        if not self._supported_table_args and self.mapping:
178            depth = self.depth()
179
180            if not depth:  # None
181                self._supported_table_args = tuple()
182            elif 1 <= depth <= 3:
183                self._supported_table_args = exp.TABLE_PARTS[:depth]
184            else:
185                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
186
187        return self._supported_table_args
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
189    def table_parts(self, table: exp.Table) -> t.List[str]:
190        return [p.name for p in reversed(table.parts)]
def udf_parts(self, udf: sqlglot.expressions.Anonymous) -> List[str]:
192    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
193        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
194        parent = udf.parent
195        parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
196        return list(reversed(parts))[0 : self.udf_depth()]
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
224    def find(
225        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
226    ) -> t.Optional[t.Any]:
227        """
228        Returns the schema of a given table.
229
230        Args:
231            table: the target table.
232            raise_on_missing: whether to raise in case the schema is not found.
233            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
234
235        Returns:
236            The schema of the target table.
237        """
238        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
239        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)
240
241        if resolved_parts is None:
242            return None
243
244        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
  • ensure_data_types: whether to convert str types to their DataType equivalents.
Returns:

The schema of the target table.

def find_udf( self, udf: sqlglot.expressions.Anonymous, raise_on_missing: bool = False) -> Optional[Any]:
246    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
247        """
248        Returns the return type of a given UDF.
249
250        Args:
251            udf: the target UDF expression.
252            raise_on_missing: whether to raise if the UDF is not found.
253
254        Returns:
255            The return type of the UDF, or None if not found.
256        """
257        parts = self.udf_parts(udf)
258        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)
259
260        if resolved_parts is None:
261            return None
262
263        return nested_get(
264            self.udf_mapping,
265            *zip(resolved_parts, reversed(resolved_parts)),
266            raise_on_missing=raise_on_missing,
267        )

Returns the return type of a given UDF.

Arguments:
  • udf: the target UDF expression.
  • raise_on_missing: whether to raise if the UDF is not found.
Returns:

The return type of the UDF, or None if not found.

def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
269    def nested_get(
270        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
271    ) -> t.Optional[t.Any]:
272        return nested_get(
273            d or self.mapping,
274            *zip(self.supported_table_args, reversed(parts)),
275            raise_on_missing=raise_on_missing,
276        )
class MappingSchema(AbstractMappingSchema, Schema):
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
        udf_mapping: Optional nested mapping of UDF names to their return types.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
        udf_mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        self._dialect = Dialect.get_or_raise(dialect)
        # Caches string -> DataType conversions performed by _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Cached schema depth, computed lazily in depth().
        self._depth = 0
        schema = {} if schema is None else schema
        udf_mapping = {} if udf_mapping is None else udf_mapping

        super().__init__(
            self._normalize(schema) if self.normalize else schema,
            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
        )

    @property
    def dialect(self) -> Dialect:
        """Returns the dialect for this mapping schema."""
        return self._dialect

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Alternate constructor: build a new instance from an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
            udf_mapping=mapping_schema.udf_mapping,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """Returns the schema of a given table, optionally coercing string types to DataType."""
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """Returns a copy of this schema; `kwargs` override the copied constructor args."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                "udf_mapping": self.udf_mapping.copy(),
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # If the table is already registered and no new columns were given,
        # keep the existing entry untouched.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Mutates the existing trie in place to include the new table's path.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Returns the column names of `table`, optionally filtered to visible columns."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Returns the type of `column` in `table`, or UNKNOWN if it can't be resolved."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Stored as a string, e.g. "int" -- parse it into a DataType.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def get_udf_type(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the return type of a UDF.

        Args:
            udf: the UDF expression or string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing string arguments.
            normalize: whether to normalize identifiers.

        Returns:
            The return type as a DataType, or UNKNOWN if not found.
        """
        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

        if resolved_parts is None:
            return exp.DataType.build("unknown")

        udf_type = nested_get(
            self.udf_mapping,
            *zip(resolved_parts, reversed(resolved_parts)),
            raise_on_missing=False,
        )

        if isinstance(udf_type, exp.DataType):
            return udf_type
        elif isinstance(udf_type, str):
            return self._to_data_type(udf_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Returns whether `column` appears in `table`'s schema."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if a table's nesting depth doesn't match the rest of the
                schema, or if a table has no columns.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                # A dict value at the column level means the mapping is nested
                # deeper than the rest of the schema.
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the UDF mapping.

        Args:
            udfs: the UDF mapping to normalize.

        Returns:
            The normalized UDF mapping.
        """
        normalized_mapping: t.Dict = {}

        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
            udf_type = nested_get(udfs, *zip(keys, keys))
            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            nested_set(normalized_mapping, normalized_keys, udf_type)

        return normalized_mapping

    def _normalize_udf(
        self,
        udf: exp.Anonymous | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Extract and normalize UDF parts for lookup.

        Args:
            udf: the UDF expression or qualified string (e.g., "db.my_func()").
            dialect: the SQL dialect for parsing.
            normalize: whether to normalize identifiers.

        Returns:
            A list of normalized UDF parts (reversed for trie lookup).

        Raises:
            SchemaError: if a string argument doesn't parse to a UDF call.
        """
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        if isinstance(udf, str):
            parsed: exp.Expression = exp.maybe_parse(udf, dialect=dialect)

            if isinstance(parsed, exp.Anonymous):
                udf = parsed
            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
                # Qualified call, e.g. db.my_func() -- the Anonymous node is
                # the right-hand side of the Dot.
                udf = parsed.expression
            else:
                raise SchemaError(f"Unable to parse UDF from: {udf!r}")
        parts = self.udf_parts(udf)

        if normalize:
            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

        return parts

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parses `table` if needed and normalizes each of its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when we're going to mutate the parts below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalizes a single identifier and returns its name as a string."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Results are memoized in `self._type_mapping_cache`.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if the type string cannot be built into a DataType.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = Dialect.get_or_raise(dialect) if dialect else self.dialect
            udt = dialect.SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                expression.transform(dialect.normalize_identifier, copy=False)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                # NOTE(review): building an invalid type appears to surface as
                # an AttributeError here -- re-raised as a SchemaError.
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: bool = True, udf_mapping: Optional[Dict] = None)
298    def __init__(
299        self,
300        schema: t.Optional[t.Dict] = None,
301        visible: t.Optional[t.Dict] = None,
302        dialect: DialectType = None,
303        normalize: bool = True,
304        udf_mapping: t.Optional[t.Dict] = None,
305    ) -> None:
306        self.visible = {} if visible is None else visible
307        self.normalize = normalize
308        self._dialect = Dialect.get_or_raise(dialect)
309        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
310        self._depth = 0
311        schema = {} if schema is None else schema
312        udf_mapping = {} if udf_mapping is None else udf_mapping
313
314        super().__init__(
315            self._normalize(schema) if self.normalize else schema,
316            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
317        )
visible
normalize
dialect: sqlglot.dialects.Dialect
319    @property
320    def dialect(self) -> Dialect:
321        """Returns the dialect for this mapping schema."""
322        return self._dialect

Returns the dialect for this mapping schema.

@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
324    @classmethod
325    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
326        return MappingSchema(
327            schema=mapping_schema.mapping,
328            visible=mapping_schema.visible,
329            dialect=mapping_schema.dialect,
330            normalize=mapping_schema.normalize,
331            udf_mapping=mapping_schema.udf_mapping,
332        )
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
334    def find(
335        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
336    ) -> t.Optional[t.Any]:
337        schema = super().find(
338            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
339        )
340        if ensure_data_types and isinstance(schema, dict):
341            schema = {
342                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
343                for col, dtype in schema.items()
344            }
345
346        return schema

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
  • ensure_data_types: whether to convert str types to their DataType equivalents.
Returns:

The schema of the target table.

def copy(self, **kwargs) -> MappingSchema:
348    def copy(self, **kwargs) -> MappingSchema:
349        return MappingSchema(
350            **{  # type: ignore
351                "schema": self.mapping.copy(),
352                "visible": self.visible.copy(),
353                "dialect": self.dialect,
354                "normalize": self.normalize,
355                "udf_mapping": self.udf_mapping.copy(),
356                **kwargs,
357            }
358        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
360    def add_table(
361        self,
362        table: exp.Table | str,
363        column_mapping: t.Optional[ColumnMapping] = None,
364        dialect: DialectType = None,
365        normalize: t.Optional[bool] = None,
366        match_depth: bool = True,
367    ) -> None:
368        """
369        Register or update a table. Updates are only performed if a new column mapping is provided.
370        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
371
372        Args:
373            table: the `Table` expression instance or string representing the table.
374            column_mapping: a column mapping that describes the structure of the table.
375            dialect: the SQL dialect that will be used to parse `table` if it's a string.
376            normalize: whether to normalize identifiers according to the dialect of interest.
377            match_depth: whether to enforce that the table must match the schema's depth or not.
378        """
379        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
380
381        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
382            raise SchemaError(
383                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
384                f"schema's nesting level: {self.depth()}."
385            )
386
387        normalized_column_mapping = {
388            self._normalize_name(key, dialect=dialect, normalize=normalize): value
389            for key, value in ensure_column_mapping(column_mapping).items()
390        }
391
392        schema = self.find(normalized_table, raise_on_missing=False)
393        if schema and not normalized_column_mapping:
394            return
395
396        parts = self.table_parts(normalized_table)
397
398        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
399        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
401    def column_names(
402        self,
403        table: exp.Table | str,
404        only_visible: bool = False,
405        dialect: DialectType = None,
406        normalize: t.Optional[bool] = None,
407    ) -> t.List[str]:
408        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
409
410        schema = self.find(normalized_table)
411        if schema is None:
412            return []
413
414        if not only_visible or not self.visible:
415            return list(schema)
416
417        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
418        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
420    def get_column_type(
421        self,
422        table: exp.Table | str,
423        column: exp.Column | str,
424        dialect: DialectType = None,
425        normalize: t.Optional[bool] = None,
426    ) -> exp.DataType:
427        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
428
429        normalized_column_name = self._normalize_name(
430            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
431        )
432
433        table_schema = self.find(normalized_table, raise_on_missing=False)
434        if table_schema:
435            column_type = table_schema.get(normalized_column_name)
436
437            if isinstance(column_type, exp.DataType):
438                return column_type
439            elif isinstance(column_type, str):
440                return self._to_data_type(column_type, dialect=dialect)
441
442        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def get_udf_type( self, udf: sqlglot.expressions.Anonymous | str, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
444    def get_udf_type(
445        self,
446        udf: exp.Anonymous | str,
447        dialect: DialectType = None,
448        normalize: t.Optional[bool] = None,
449    ) -> exp.DataType:
450        """
451        Get the return type of a UDF.
452
453        Args:
454            udf: the UDF expression or string (e.g., "db.my_func()").
455            dialect: the SQL dialect for parsing string arguments.
456            normalize: whether to normalize identifiers.
457
458        Returns:
459            The return type as a DataType, or UNKNOWN if not found.
460        """
461        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
462        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)
463
464        if resolved_parts is None:
465            return exp.DataType.build("unknown")
466
467        udf_type = nested_get(
468            self.udf_mapping,
469            *zip(resolved_parts, reversed(resolved_parts)),
470            raise_on_missing=False,
471        )
472
473        if isinstance(udf_type, exp.DataType):
474            return udf_type
475        elif isinstance(udf_type, str):
476            return self._to_data_type(udf_type, dialect=dialect)
477
478        return exp.DataType.build("unknown")

Get the return type of a UDF.

Arguments:
  • udf: the UDF expression or string (e.g., "db.my_func()").
  • dialect: the SQL dialect for parsing string arguments.
  • normalize: whether to normalize identifiers.
Returns:

The return type as a DataType, or UNKNOWN if not found.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
480    def has_column(
481        self,
482        table: exp.Table | str,
483        column: exp.Column | str,
484        dialect: DialectType = None,
485        normalize: t.Optional[bool] = None,
486    ) -> bool:
487        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
488
489        normalized_column_name = self._normalize_name(
490            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
491        )
492
493        table_schema = self.find(normalized_table, raise_on_missing=False)
494        return normalized_column_name in table_schema if table_schema else False

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
623    def depth(self) -> int:
624        if not self.empty and not self._depth:
625            # The columns themselves are a mapping, but we don't want to include those
626            self._depth = super().depth() - 1
627        return self._depth
def normalize_name( identifier: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> sqlglot.expressions.Identifier:
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if needed and normalize it per the dialect's rules."""
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return ident

    # normalize_identifier consults this flag; BigQuery has special rules for tables.
    ident.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(ident)
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a Schema; dicts and None are wrapped in a MappingSchema."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
def ensure_column_mapping(mapping: Union[Dict, str, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping into a dict of {column_name: column_type}.

    Args:
        mapping: one of:
            - None: produces an empty mapping.
            - dict: returned as-is.
            - str: comma-separated "name: type" pairs, e.g. "a: INT, b: TEXT".
            - list: column names only; types default to None.

    Returns:
        The column mapping as a dict.

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        columns: t.Dict[str, str] = {}
        for pair in mapping.split(","):
            # Split on the FIRST colon only, so types that themselves contain
            # colons (e.g. "m: MAP<INT:INT>") are preserved intact; the previous
            # split(":")[1] silently truncated everything after the second colon.
            parts = pair.split(":", 1)
            columns[parts[0].strip()] = parts[1].strip()
        return columns
    if isinstance(mapping, list):
        return {column.strip(): None for column in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: Optional[int] = None, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested schema into a list of table paths (one list of parts per table)."""
    prefix = keys or []
    if depth is None:
        # Exclude the innermost (column) level from the traversal depth.
        depth = dict_depth(schema) - 1

    paths: t.List[t.List[str]] = []
    for key, value in schema.items():
        if depth == 1 or not isinstance(value, dict):
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value from a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether to raise when a key along the path is absent.

    Returns:
        The value, or None if it doesn't exist and `raise_on_missing` is False.
    """
    current: t.Optional[t.Any] = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the internal arg name for a table's own identifier.
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get. name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk (creating intermediate dicts as needed) down to the parent of the leaf.
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.