Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth, first
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dialects.dialect import DialectType
 14
 15    ColumnMapping = t.Union[t.Dict, str, t.List]
 16
 17
 18class Schema(abc.ABC):
 19    """Abstract base class for database schemas"""
 20
 21    dialect: DialectType
 22
 23    @abc.abstractmethod
 24    def add_table(
 25        self,
 26        table: exp.Table | str,
 27        column_mapping: t.Optional[ColumnMapping] = None,
 28        dialect: DialectType = None,
 29        normalize: t.Optional[bool] = None,
 30        match_depth: bool = True,
 31    ) -> None:
 32        """
 33        Register or update a table. Some implementing classes may require column information to also be provided.
 34        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 35
 36        Args:
 37            table: the `Table` expression instance or string representing the table.
 38            column_mapping: a column mapping that describes the structure of the table.
 39            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 40            normalize: whether to normalize identifiers according to the dialect of interest.
 41            match_depth: whether to enforce that the table must match the schema's depth or not.
 42        """
 43
 44    @abc.abstractmethod
 45    def column_names(
 46        self,
 47        table: exp.Table | str,
 48        only_visible: bool = False,
 49        dialect: DialectType = None,
 50        normalize: t.Optional[bool] = None,
 51    ) -> t.Sequence[str]:
 52        """
 53        Get the column names for a table.
 54
 55        Args:
 56            table: the `Table` expression instance.
 57            only_visible: whether to include invisible columns.
 58            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 59            normalize: whether to normalize identifiers according to the dialect of interest.
 60
 61        Returns:
 62            The sequence of column names.
 63        """
 64
 65    @abc.abstractmethod
 66    def get_column_type(
 67        self,
 68        table: exp.Table | str,
 69        column: exp.Column | str,
 70        dialect: DialectType = None,
 71        normalize: t.Optional[bool] = None,
 72    ) -> exp.DataType:
 73        """
 74        Get the `sqlglot.exp.DataType` type of a column in the schema.
 75
 76        Args:
 77            table: the source table.
 78            column: the target column.
 79            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 80            normalize: whether to normalize identifiers according to the dialect of interest.
 81
 82        Returns:
 83            The resulting column type.
 84        """
 85
 86    def has_column(
 87        self,
 88        table: exp.Table | str,
 89        column: exp.Column | str,
 90        dialect: DialectType = None,
 91        normalize: t.Optional[bool] = None,
 92    ) -> bool:
 93        """
 94        Returns whether `column` appears in `table`'s schema.
 95
 96        Args:
 97            table: the source table.
 98            column: the target column.
 99            dialect: the SQL dialect that will be used to parse `table` if it's a string.
100            normalize: whether to normalize identifiers according to the dialect of interest.
101
102        Returns:
103            True if the column appears in the schema, False otherwise.
104        """
105        name = column if isinstance(column, str) else column.name
106        return name in self.column_names(table, dialect=dialect, normalize=normalize)
107
108    @property
109    @abc.abstractmethod
110    def supported_table_args(self) -> t.Tuple[str, ...]:
111        """
112        Table arguments this schema support, e.g. `("this", "db", "catalog")`
113        """
114
115    @property
116    def empty(self) -> bool:
117        """Returns whether the schema is empty."""
118        return True
119
120
class AbstractMappingSchema:
    """
    Base for schemas backed by a nested dict `mapping`.

    Alongside the mapping, a trie over its table paths (stored innermost name first) is
    maintained so that `find` can resolve tables that are only partially qualified.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Paths from flatten_schema run outermost-first; reversing them lets a partially
        # qualified table (e.g. just the table name) be matched as a trie prefix.
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the mapping has no entries."""
        return not self.mapping

    def depth(self) -> int:
        """The nesting depth of the mapping, as computed by `dict_depth`."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The table arguments addressable by this mapping, derived from its depth.

        The result is cached once it is non-empty.

        Raises:
            SchemaError: if the (non-zero) depth is not between 1 and 3.
        """
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """
        Return the names of `table`'s parts in reverse order of `table.parts`.

        Assuming `table.parts` is outermost-first, the table name comes first; this matches
        the key order used for `mapping_trie`.
        """
        return [part.name for part in reversed(table.parts)]

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents
                (ignored by this base implementation).

        Returns:
            The schema of the target table, or None if it is not found (or is ambiguous and
            `raise_on_missing` is False).

        Raises:
            SchemaError: if the table matches several entries and `raise_on_missing` is set.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table is under-qualified: it is only resolvable if exactly one
            # fully-qualified entry extends it.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(path) for path in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """
        Look up `parts` (innermost-first, as produced by `table_parts`) in `d`.

        Args:
            parts: the table path, innermost name first.
            d: the mapping to search; defaults to `self.mapping` when None.
            raise_on_missing: whether to raise if a level of the path is missing.

        Returns:
            The value found, or None if missing and `raise_on_missing` is False.
        """
        # Both sequences are reversed so they run outermost-first together, which pairs each
        # key with the table arg naming its level (used only in "Unknown ..." errors).
        return nested_get(
            self.mapping if d is None else d,
            *zip(reversed(self.supported_table_args), reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
197
198
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Parsed column types keyed by (type string, dialect): the same string can resolve
        # to a different DataType depending on the dialect used to parse it.
        self._type_mapping_cache: t.Dict[t.Tuple[str, DialectType], exp.DataType] = {}
        # Must exist before super().__init__(), which calls self.depth() to build the trie.
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """
        Build a new schema from the mapping, visibility, dialect and normalization settings
        of an existing one. Uses `cls`, so subclasses get an instance of their own type.
        """
        return cls(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` column types to their `DataType` equivalents.

        Returns:
            The schema of the target table. When `ensure_data_types` is set and the schema is a
            dict, a new dict is returned with string types converted.
        """
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """
        Return a new `MappingSchema` with this schema's settings; `kwargs` override the
        constructor arguments. The mapping and visibility dicts are copied shallowly.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.

        Raises:
            SchemaError: if `match_depth` is set, the schema is non-empty and the table's number
                of parts differs from the schema's depth.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Re-registering an already-known table without new columns is a no-op.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Keep the lookup trie in sync with the mapping (trie keys are innermost-first).
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table, or an empty list if the table is not found.

        When `only_visible` is set and a visibility mapping was provided, only the columns
        listed in it for this table are returned, in schema order.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the type of `column` in `table`; string types are parsed into `DataType`s.
        Returns the "unknown" type if the table or column is not found.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Returns whether the (normalized) `column` name is a key of `table`'s schema."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.

        Raises:
            SchemaError: if a table entry is not a dict, has no columns, or is nested deeper
                than the other tables.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and, when normalizing, normalize its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when normalizing, since the parts are replaced in place below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Return the normalized name, falling back to the schema's dialect/normalize settings."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """The number of table-path levels (mapping depth minus the column level), cached."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Results are cached per (type string, dialect) pair.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.

        Raises:
            SchemaError: if building the type raises an AttributeError.
        """
        dialect = dialect or self.dialect
        cache_key = (schema_type, dialect)

        if cache_key not in self._type_mapping_cache:
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[cache_key] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[cache_key]
469
470
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """
    Turn `identifier` into an `exp.Identifier`, normalized according to `dialect`.

    Strings are parsed into identifiers first. If `normalize` is falsy, the (parsed)
    identifier is returned as-is; otherwise the dialect's `normalize_identifier` is applied.
    Note that when normalizing, the identifier's `meta` is updated in place.
    """
    if isinstance(identifier, str):
        ident = exp.parse_identifier(identifier, dialect=dialect)
    else:
        ident = identifier

    if normalize:
        # Consumed by normalize_identifier: BigQuery has special rules for table identifiers.
        ident.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(ident)

    return ident
486
487
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """
    Return `schema` unchanged if it is already a `Schema`; otherwise wrap the given
    mapping (or None) in a `MappingSchema`, forwarding `kwargs` to its constructor.
    """
    already_schema = isinstance(schema, Schema)
    return schema if already_schema else MappingSchema(schema, **kwargs)
493
494
def _split_top_level(text: str, sep: str = ",") -> t.List[str]:
    """Split `text` on `sep`, ignoring separators nested inside (), <> or [] brackets."""
    pieces: t.List[str] = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "(<[":
            depth += 1
        elif char in ")>]":
            # Clamp so a stray closing bracket cannot disable splitting for the rest of the text.
            depth = max(depth - 1, 0)
        elif char == sep and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce any supported column-mapping form into a dict.

    Args:
        mapping: one of
            - None: yields an empty dict;
            - a dict: returned as-is (not copied);
            - a string such as ``"a: INT, b: DECIMAL(10, 2)"``: parsed into ``{name: type}``.
              Commas nested inside (), <> or [] do not split columns, and empty entries
              (e.g. from a trailing comma) are ignored;
            - a list of column names: each name is mapped to ``None``.

    Returns:
        The column mapping as a dict.

    Raises:
        ValueError: if `mapping` has an unsupported type, or a string entry is not of the
            form ``name: type``.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        columns: t.Dict = {}
        for raw_entry in _split_top_level(mapping):
            entry = raw_entry.strip()
            if not entry:
                continue
            # partition keeps any ':' inside the type itself, e.g. "STRUCT<x: INT>".
            name, colon, type_str = entry.partition(":")
            if not colon:
                raise ValueError(f"Invalid column definition (expected 'name: type'): {entry!r}")
            columns[name.strip()] = type_str.strip()
        return columns
    if isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
510
511
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Collect the key paths of a nested mapping down to `depth` levels.

    Args:
        schema: the nested mapping to walk.
        depth: how many levels to descend; defaults to the mapping's depth minus one
            (i.e. excluding the innermost column level).
        keys: a prefix prepended to every returned path.

    Returns:
        A list of key paths, outermost key first. A non-dict value ends its path early.
    """
    prefix = keys or []
    remaining = dict_depth(schema) - 1 if depth is None else depth
    paths = []

    for key, value in schema.items():
        path = prefix + [key]
        is_leaf = remaining == 1 or not isinstance(value, dict)
        if is_leaf:
            paths.append(path)
        elif remaining >= 2:
            paths.extend(flatten_schema(value, remaining - 1, path))

    return paths
526
527
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Walk a nested dictionary along `path` and return the value at its end.

    Args:
        d: the dictionary to search.
        *path: (name, key) pairs; `key` is looked up at each level, and `name` labels that
            level in the error raised when the key is absent ("this" is reported as "table").
        raise_on_missing: whether a missing key raises instead of returning None.

    Returns:
        The value found, or None if a key is missing and `raise_on_missing` is False.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is True.
    """
    current = d
    for label, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue
        if not raise_on_missing:
            return None
        kind = "table" if label == "this" else label
        raise ValueError(f"Unknown {kind}: {key}")

    return current
552
553
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary, creating missing intermediate dicts.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`; an empty path leaves `d` untouched.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary, i.e. `d` itself.
    """
    if not keys:
        return d

    *parents, leaf = keys
    node = d
    for key in parents:
        node = node[key] if key in node else node.setdefault(key, {})
    node[leaf] = value
    return d
class Schema(abc.ABC):
 19class Schema(abc.ABC):
 20    """Abstract base class for database schemas"""
 21
 22    dialect: DialectType
 23
 24    @abc.abstractmethod
 25    def add_table(
 26        self,
 27        table: exp.Table | str,
 28        column_mapping: t.Optional[ColumnMapping] = None,
 29        dialect: DialectType = None,
 30        normalize: t.Optional[bool] = None,
 31        match_depth: bool = True,
 32    ) -> None:
 33        """
 34        Register or update a table. Some implementing classes may require column information to also be provided.
 35        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 36
 37        Args:
 38            table: the `Table` expression instance or string representing the table.
 39            column_mapping: a column mapping that describes the structure of the table.
 40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 41            normalize: whether to normalize identifiers according to the dialect of interest.
 42            match_depth: whether to enforce that the table must match the schema's depth or not.
 43        """
 44
 45    @abc.abstractmethod
 46    def column_names(
 47        self,
 48        table: exp.Table | str,
 49        only_visible: bool = False,
 50        dialect: DialectType = None,
 51        normalize: t.Optional[bool] = None,
 52    ) -> t.Sequence[str]:
 53        """
 54        Get the column names for a table.
 55
 56        Args:
 57            table: the `Table` expression instance.
 58            only_visible: whether to include invisible columns.
 59            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 60            normalize: whether to normalize identifiers according to the dialect of interest.
 61
 62        Returns:
 63            The sequence of column names.
 64        """
 65
 66    @abc.abstractmethod
 67    def get_column_type(
 68        self,
 69        table: exp.Table | str,
 70        column: exp.Column | str,
 71        dialect: DialectType = None,
 72        normalize: t.Optional[bool] = None,
 73    ) -> exp.DataType:
 74        """
 75        Get the `sqlglot.exp.DataType` type of a column in the schema.
 76
 77        Args:
 78            table: the source table.
 79            column: the target column.
 80            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 81            normalize: whether to normalize identifiers according to the dialect of interest.
 82
 83        Returns:
 84            The resulting column type.
 85        """
 86
 87    def has_column(
 88        self,
 89        table: exp.Table | str,
 90        column: exp.Column | str,
 91        dialect: DialectType = None,
 92        normalize: t.Optional[bool] = None,
 93    ) -> bool:
 94        """
 95        Returns whether `column` appears in `table`'s schema.
 96
 97        Args:
 98            table: the source table.
 99            column: the target column.
100            dialect: the SQL dialect that will be used to parse `table` if it's a string.
101            normalize: whether to normalize identifiers according to the dialect of interest.
102
103        Returns:
104            True if the column appears in the schema, False otherwise.
105        """
106        name = column if isinstance(column, str) else column.name
107        return name in self.column_names(table, dialect=dialect, normalize=normalize)
108
109    @property
110    @abc.abstractmethod
111    def supported_table_args(self) -> t.Tuple[str, ...]:
112        """
113        Table arguments this schema support, e.g. `("this", "db", "catalog")`
114        """
115
116    @property
117    def empty(self) -> bool:
118        """Returns whether the schema is empty."""
119        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
24    @abc.abstractmethod
25    def add_table(
26        self,
27        table: exp.Table | str,
28        column_mapping: t.Optional[ColumnMapping] = None,
29        dialect: DialectType = None,
30        normalize: t.Optional[bool] = None,
31        match_depth: bool = True,
32    ) -> None:
33        """
34        Register or update a table. Some implementing classes may require column information to also be provided.
35        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
36
37        Args:
38            table: the `Table` expression instance or string representing the table.
39            column_mapping: a column mapping that describes the structure of the table.
40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
41            normalize: whether to normalize identifiers according to the dialect of interest.
42            match_depth: whether to enforce that the table must match the schema's depth or not.
43        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> Sequence[str]:
45    @abc.abstractmethod
46    def column_names(
47        self,
48        table: exp.Table | str,
49        only_visible: bool = False,
50        dialect: DialectType = None,
51        normalize: t.Optional[bool] = None,
52    ) -> t.Sequence[str]:
53        """
54        Get the column names for a table.
55
56        Args:
57            table: the `Table` expression instance.
58            only_visible: whether to include invisible columns.
59            dialect: the SQL dialect that will be used to parse `table` if it's a string.
60            normalize: whether to normalize identifiers according to the dialect of interest.
61
62        Returns:
63            The sequence of column names.
64        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance or string representing the table.
  • only_visible: whether to return only the visible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
66    @abc.abstractmethod
67    def get_column_type(
68        self,
69        table: exp.Table | str,
70        column: exp.Column | str,
71        dialect: DialectType = None,
72        normalize: t.Optional[bool] = None,
73    ) -> exp.DataType:
74        """
75        Get the `sqlglot.exp.DataType` type of a column in the schema.
76
77        Args:
78            table: the source table.
79            column: the target column.
80            dialect: the SQL dialect that will be used to parse `table` if it's a string.
81            normalize: whether to normalize identifiers according to the dialect of interest.
82
83        Returns:
84            The resulting column type.
85        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 87    def has_column(
 88        self,
 89        table: exp.Table | str,
 90        column: exp.Column | str,
 91        dialect: DialectType = None,
 92        normalize: t.Optional[bool] = None,
 93    ) -> bool:
 94        """
 95        Returns whether `column` appears in `table`'s schema.
 96
 97        Args:
 98            table: the source table.
 99            column: the target column.
100            dialect: the SQL dialect that will be used to parse `table` if it's a string.
101            normalize: whether to normalize identifiers according to the dialect of interest.
102
103        Returns:
104            True if the column appears in the schema, False otherwise.
105        """
106        name = column if isinstance(column, str) else column.name
107        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

supported_table_args: Tuple[str, ...]
109    @property
110    @abc.abstractmethod
111    def supported_table_args(self) -> t.Tuple[str, ...]:
112        """
113        Table arguments this schema support, e.g. `("this", "db", "catalog")`
114        """

Table arguments this schema supports, e.g. ("this", "db", "catalog").

empty: bool
116    @property
117    def empty(self) -> bool:
118        """Returns whether the schema is empty."""
119        return True

Returns whether the schema is empty.

class AbstractMappingSchema:
122class AbstractMappingSchema:
123    def __init__(
124        self,
125        mapping: t.Optional[t.Dict] = None,
126    ) -> None:
127        self.mapping = mapping or {}
128        self.mapping_trie = new_trie(
129            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
130        )
131        self._supported_table_args: t.Tuple[str, ...] = tuple()
132
133    @property
134    def empty(self) -> bool:
135        return not self.mapping
136
137    def depth(self) -> int:
138        return dict_depth(self.mapping)
139
140    @property
141    def supported_table_args(self) -> t.Tuple[str, ...]:
142        if not self._supported_table_args and self.mapping:
143            depth = self.depth()
144
145            if not depth:  # None
146                self._supported_table_args = tuple()
147            elif 1 <= depth <= 3:
148                self._supported_table_args = exp.TABLE_PARTS[:depth]
149            else:
150                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
151
152        return self._supported_table_args
153
154    def table_parts(self, table: exp.Table) -> t.List[str]:
155        return [part.name for part in reversed(table.parts)]
156
157    def find(
158        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
159    ) -> t.Optional[t.Any]:
160        """
161        Returns the schema of a given table.
162
163        Args:
164            table: the target table.
165            raise_on_missing: whether to raise in case the schema is not found.
166            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
167
168        Returns:
169            The schema of the target table.
170        """
171        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
172        value, trie = in_trie(self.mapping_trie, parts)
173
174        if value == TrieResult.FAILED:
175            return None
176
177        if value == TrieResult.PREFIX:
178            possibilities = flatten_schema(trie)
179
180            if len(possibilities) == 1:
181                parts.extend(possibilities[0])
182            else:
183                message = ", ".join(".".join(parts) for parts in possibilities)
184                if raise_on_missing:
185                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
186                return None
187
188        return self.nested_get(parts, raise_on_missing=raise_on_missing)
189
190    def nested_get(
191        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
192    ) -> t.Optional[t.Any]:
193        return nested_get(
194            d or self.mapping,
195            *zip(self.supported_table_args, reversed(parts)),
196            raise_on_missing=raise_on_missing,
197        )
AbstractMappingSchema(mapping: Optional[Dict] = None)
123    def __init__(
124        self,
125        mapping: t.Optional[t.Dict] = None,
126    ) -> None:
127        self.mapping = mapping or {}
128        self.mapping_trie = new_trie(
129            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
130        )
131        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
133    @property
134    def empty(self) -> bool:
135        return not self.mapping
def depth(self) -> int:
137    def depth(self) -> int:
138        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
140    @property
141    def supported_table_args(self) -> t.Tuple[str, ...]:
142        if not self._supported_table_args and self.mapping:
143            depth = self.depth()
144
145            if not depth:  # None
146                self._supported_table_args = tuple()
147            elif 1 <= depth <= 3:
148                self._supported_table_args = exp.TABLE_PARTS[:depth]
149            else:
150                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
151
152        return self._supported_table_args
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
154    def table_parts(self, table: exp.Table) -> t.List[str]:
155        return [part.name for part in reversed(table.parts)]
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
157    def find(
158        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
159    ) -> t.Optional[t.Any]:
160        """
161        Returns the schema of a given table.
162
163        Args:
164            table: the target table.
165            raise_on_missing: whether to raise in case the schema is not found.
166            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
167
168        Returns:
169            The schema of the target table.
170        """
171        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
172        value, trie = in_trie(self.mapping_trie, parts)
173
174        if value == TrieResult.FAILED:
175            return None
176
177        if value == TrieResult.PREFIX:
178            possibilities = flatten_schema(trie)
179
180            if len(possibilities) == 1:
181                parts.extend(possibilities[0])
182            else:
183                message = ", ".join(".".join(parts) for parts in possibilities)
184                if raise_on_missing:
185                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
186                return None
187
188        return self.nested_get(parts, raise_on_missing=raise_on_missing)

Returns the schema of a given table.

Arguments:
  • table: the target table.
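  • raise_on_missing: whether to raise if the table's mapping is ambiguous or the nested lookup fails; if the table is not found in the mapping trie at all, None is returned regardless.
  • ensure_data_types: whether to convert str types to their DataType equivalents (accepted but not applied by this base implementation).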
Returns:

The schema of the target table.

def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
190    def nested_get(
191        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
192    ) -> t.Optional[t.Any]:
193        return nested_get(
194            d or self.mapping,
195            *zip(self.supported_table_args, reversed(parts)),
196            raise_on_missing=raise_on_missing,
197        )
class MappingSchema(AbstractMappingSchema, Schema):
200class MappingSchema(AbstractMappingSchema, Schema):
201    """
202    Schema based on a nested mapping.
203
204    Args:
205        schema: Mapping in one of the following forms:
206            1. {table: {col: type}}
207            2. {db: {table: {col: type}}}
208            3. {catalog: {db: {table: {col: type}}}}
209            4. None - Tables will be added later
210        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
211            are assumed to be visible. The nesting should mirror that of the schema:
212            1. {table: set(*cols)}}
213            2. {db: {table: set(*cols)}}}
214            3. {catalog: {db: {table: set(*cols)}}}}
215        dialect: The dialect to be used for custom type mappings & parsing string arguments.
216        normalize: Whether to normalize identifier names according to the given dialect or not.
217    """
218
219    def __init__(
220        self,
221        schema: t.Optional[t.Dict] = None,
222        visible: t.Optional[t.Dict] = None,
223        dialect: DialectType = None,
224        normalize: bool = True,
225    ) -> None:
226        self.dialect = dialect
227        self.visible = {} if visible is None else visible
228        self.normalize = normalize
229        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
230        self._depth = 0
231        schema = {} if schema is None else schema
232
233        super().__init__(self._normalize(schema) if self.normalize else schema)
234
235    @classmethod
236    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
237        return MappingSchema(
238            schema=mapping_schema.mapping,
239            visible=mapping_schema.visible,
240            dialect=mapping_schema.dialect,
241            normalize=mapping_schema.normalize,
242        )
243
244    def find(
245        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
246    ) -> t.Optional[t.Any]:
247        schema = super().find(
248            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
249        )
250        if ensure_data_types and isinstance(schema, dict):
251            schema = {
252                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
253                for col, dtype in schema.items()
254            }
255
256        return schema
257
258    def copy(self, **kwargs) -> MappingSchema:
259        return MappingSchema(
260            **{  # type: ignore
261                "schema": self.mapping.copy(),
262                "visible": self.visible.copy(),
263                "dialect": self.dialect,
264                "normalize": self.normalize,
265                **kwargs,
266            }
267        )
268
269    def add_table(
270        self,
271        table: exp.Table | str,
272        column_mapping: t.Optional[ColumnMapping] = None,
273        dialect: DialectType = None,
274        normalize: t.Optional[bool] = None,
275        match_depth: bool = True,
276    ) -> None:
277        """
278        Register or update a table. Updates are only performed if a new column mapping is provided.
279        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
280
281        Args:
282            table: the `Table` expression instance or string representing the table.
283            column_mapping: a column mapping that describes the structure of the table.
284            dialect: the SQL dialect that will be used to parse `table` if it's a string.
285            normalize: whether to normalize identifiers according to the dialect of interest.
286            match_depth: whether to enforce that the table must match the schema's depth or not.
287        """
288        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
289
290        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
291            raise SchemaError(
292                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
293                f"schema's nesting level: {self.depth()}."
294            )
295
296        normalized_column_mapping = {
297            self._normalize_name(key, dialect=dialect, normalize=normalize): value
298            for key, value in ensure_column_mapping(column_mapping).items()
299        }
300
301        schema = self.find(normalized_table, raise_on_missing=False)
302        if schema and not normalized_column_mapping:
303            return
304
305        parts = self.table_parts(normalized_table)
306
307        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
308        new_trie([parts], self.mapping_trie)
309
310    def column_names(
311        self,
312        table: exp.Table | str,
313        only_visible: bool = False,
314        dialect: DialectType = None,
315        normalize: t.Optional[bool] = None,
316    ) -> t.List[str]:
317        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
318
319        schema = self.find(normalized_table)
320        if schema is None:
321            return []
322
323        if not only_visible or not self.visible:
324            return list(schema)
325
326        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
327        return [col for col in schema if col in visible]
328
329    def get_column_type(
330        self,
331        table: exp.Table | str,
332        column: exp.Column | str,
333        dialect: DialectType = None,
334        normalize: t.Optional[bool] = None,
335    ) -> exp.DataType:
336        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
337
338        normalized_column_name = self._normalize_name(
339            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
340        )
341
342        table_schema = self.find(normalized_table, raise_on_missing=False)
343        if table_schema:
344            column_type = table_schema.get(normalized_column_name)
345
346            if isinstance(column_type, exp.DataType):
347                return column_type
348            elif isinstance(column_type, str):
349                return self._to_data_type(column_type, dialect=dialect)
350
351        return exp.DataType.build("unknown")
352
353    def has_column(
354        self,
355        table: exp.Table | str,
356        column: exp.Column | str,
357        dialect: DialectType = None,
358        normalize: t.Optional[bool] = None,
359    ) -> bool:
360        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
361
362        normalized_column_name = self._normalize_name(
363            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
364        )
365
366        table_schema = self.find(normalized_table, raise_on_missing=False)
367        return normalized_column_name in table_schema if table_schema else False
368
369    def _normalize(self, schema: t.Dict) -> t.Dict:
370        """
371        Normalizes all identifiers in the schema.
372
373        Args:
374            schema: the schema to normalize.
375
376        Returns:
377            The normalized schema mapping.
378        """
379        normalized_mapping: t.Dict = {}
380        flattened_schema = flatten_schema(schema)
381        error_msg = "Table {} must match the schema's nesting level: {}."
382
383        for keys in flattened_schema:
384            columns = nested_get(schema, *zip(keys, keys))
385
386            if not isinstance(columns, dict):
387                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
388            if not columns:
389                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
390            if isinstance(first(columns.values()), dict):
391                raise SchemaError(
392                    error_msg.format(
393                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
394                    ),
395                )
396
397            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
398            for column_name, column_type in columns.items():
399                nested_set(
400                    normalized_mapping,
401                    normalized_keys + [self._normalize_name(column_name)],
402                    column_type,
403                )
404
405        return normalized_mapping
406
407    def _normalize_table(
408        self,
409        table: exp.Table | str,
410        dialect: DialectType = None,
411        normalize: t.Optional[bool] = None,
412    ) -> exp.Table:
413        dialect = dialect or self.dialect
414        normalize = self.normalize if normalize is None else normalize
415
416        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)
417
418        if normalize:
419            for part in normalized_table.parts:
420                if isinstance(part, exp.Identifier):
421                    part.replace(
422                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
423                    )
424
425        return normalized_table
426
427    def _normalize_name(
428        self,
429        name: str | exp.Identifier,
430        dialect: DialectType = None,
431        is_table: bool = False,
432        normalize: t.Optional[bool] = None,
433    ) -> str:
434        return normalize_name(
435            name,
436            dialect=dialect or self.dialect,
437            is_table=is_table,
438            normalize=self.normalize if normalize is None else normalize,
439        ).name
440
441    def depth(self) -> int:
442        if not self.empty and not self._depth:
443            # The columns themselves are a mapping, but we don't want to include those
444            self._depth = super().depth() - 1
445        return self._depth
446
447    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
448        """
449        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
450
451        Args:
452            schema_type: the type we want to convert.
453            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
454
455        Returns:
456            The resulting expression type.
457        """
458        if schema_type not in self._type_mapping_cache:
459            dialect = dialect or self.dialect
460            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES
461
462            try:
463                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
464                self._type_mapping_cache[schema_type] = expression
465            except AttributeError:
466                in_dialect = f" in dialect {dialect}" if dialect else ""
467                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
468
469        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
219    def __init__(
220        self,
221        schema: t.Optional[t.Dict] = None,
222        visible: t.Optional[t.Dict] = None,
223        dialect: DialectType = None,
224        normalize: bool = True,
225    ) -> None:
226        self.dialect = dialect
227        self.visible = {} if visible is None else visible
228        self.normalize = normalize
229        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
230        self._depth = 0
231        schema = {} if schema is None else schema
232
233        super().__init__(self._normalize(schema) if self.normalize else schema)
dialect
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
235    @classmethod
236    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
237        return MappingSchema(
238            schema=mapping_schema.mapping,
239            visible=mapping_schema.visible,
240            dialect=mapping_schema.dialect,
241            normalize=mapping_schema.normalize,
242        )
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
244    def find(
245        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
246    ) -> t.Optional[t.Any]:
247        schema = super().find(
248            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
249        )
250        if ensure_data_types and isinstance(schema, dict):
251            schema = {
252                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
253                for col, dtype in schema.items()
254            }
255
256        return schema

Returns the schema of a given table.

Arguments:
  • table: the target table.
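  • raise_on_missing: whether to raise if the table's mapping is ambiguous or the nested lookup fails; if the table is not found in the mapping trie at all, None is returned regardless.
  • ensure_data_types: whether to convert str types to their DataType equivalents.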
Returns:

The schema of the target table.

def copy(self, **kwargs) -> MappingSchema:
258    def copy(self, **kwargs) -> MappingSchema:
259        return MappingSchema(
260            **{  # type: ignore
261                "schema": self.mapping.copy(),
262                "visible": self.visible.copy(),
263                "dialect": self.dialect,
264                "normalize": self.normalize,
265                **kwargs,
266            }
267        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
269    def add_table(
270        self,
271        table: exp.Table | str,
272        column_mapping: t.Optional[ColumnMapping] = None,
273        dialect: DialectType = None,
274        normalize: t.Optional[bool] = None,
275        match_depth: bool = True,
276    ) -> None:
277        """
278        Register or update a table. Updates are only performed if a new column mapping is provided.
279        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
280
281        Args:
282            table: the `Table` expression instance or string representing the table.
283            column_mapping: a column mapping that describes the structure of the table.
284            dialect: the SQL dialect that will be used to parse `table` if it's a string.
285            normalize: whether to normalize identifiers according to the dialect of interest.
286            match_depth: whether to enforce that the table must match the schema's depth or not.
287        """
288        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
289
290        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
291            raise SchemaError(
292                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
293                f"schema's nesting level: {self.depth()}."
294            )
295
296        normalized_column_mapping = {
297            self._normalize_name(key, dialect=dialect, normalize=normalize): value
298            for key, value in ensure_column_mapping(column_mapping).items()
299        }
300
301        schema = self.find(normalized_table, raise_on_missing=False)
302        if schema and not normalized_column_mapping:
303            return
304
305        parts = self.table_parts(normalized_table)
306
307        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
308        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
310    def column_names(
311        self,
312        table: exp.Table | str,
313        only_visible: bool = False,
314        dialect: DialectType = None,
315        normalize: t.Optional[bool] = None,
316    ) -> t.List[str]:
317        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
318
319        schema = self.find(normalized_table)
320        if schema is None:
321            return []
322
323        if not only_visible or not self.visible:
324            return list(schema)
325
326        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
327        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance or a string representing the table.
  • only_visible: whether to return only the visible columns (invisible columns are excluded when True).
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
329    def get_column_type(
330        self,
331        table: exp.Table | str,
332        column: exp.Column | str,
333        dialect: DialectType = None,
334        normalize: t.Optional[bool] = None,
335    ) -> exp.DataType:
336        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
337
338        normalized_column_name = self._normalize_name(
339            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
340        )
341
342        table_schema = self.find(normalized_table, raise_on_missing=False)
343        if table_schema:
344            column_type = table_schema.get(normalized_column_name)
345
346            if isinstance(column_type, exp.DataType):
347                return column_type
348            elif isinstance(column_type, str):
349                return self._to_data_type(column_type, dialect=dialect)
350
351        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
353    def has_column(
354        self,
355        table: exp.Table | str,
356        column: exp.Column | str,
357        dialect: DialectType = None,
358        normalize: t.Optional[bool] = None,
359    ) -> bool:
360        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
361
362        normalized_column_name = self._normalize_name(
363            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
364        )
365
366        table_schema = self.find(normalized_table, raise_on_missing=False)
367        return normalized_column_name in table_schema if table_schema else False

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
441    def depth(self) -> int:
442        if not self.empty and not self._depth:
443            # The columns themselves are a mapping, but we don't want to include those
444            self._depth = super().depth() - 1
445        return self._depth
def normalize_name( identifier: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> sqlglot.expressions.Identifier:
472def normalize_name(
473    identifier: str | exp.Identifier,
474    dialect: DialectType = None,
475    is_table: bool = False,
476    normalize: t.Optional[bool] = True,
477) -> exp.Identifier:
478    if isinstance(identifier, str):
479        identifier = exp.parse_identifier(identifier, dialect=dialect)
480
481    if not normalize:
482        return identifier
483
484    # this is used for normalize_identifier, bigquery has special rules pertaining tables
485    identifier.meta["is_table"] = is_table
486    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
489def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
490    if isinstance(schema, Schema):
491        return schema
492
493    return MappingSchema(schema, **kwargs)
def ensure_column_mapping(mapping: Union[Dict, str, List, NoneType]) -> Dict:
496def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
497    if mapping is None:
498        return {}
499    elif isinstance(mapping, dict):
500        return mapping
501    elif isinstance(mapping, str):
502        col_name_type_strs = [x.strip() for x in mapping.split(",")]
503        return {
504            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
505            for name_type_str in col_name_type_strs
506        }
507    elif isinstance(mapping, list):
508        return {x.strip(): None for x in mapping}
509
510    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Flatten a nested schema dictionary into a list of key paths, one per table.

    Args:
        schema: the (possibly nested) schema dictionary.
        depth: how many levels of keys form a table path. Defaults to
            `dict_depth(schema) - 1`.
        keys: a key prefix prepended to every returned path (used by the recursion).

    Returns:
        A list of key paths. A non-dict value ends its path early regardless of the
        remaining depth; dict-valued entries are dropped when `depth` is below 1.
    """
    if depth is None:
        # Presumably one less than the nesting depth so the innermost (column) level
        # is not treated as part of a table path.
        depth = dict_depth(schema) - 1
    prefix = keys or []

    paths = []
    for name, child in schema.items():
        path = prefix + [name]
        if depth == 1 or not isinstance(child, dict):
            paths.append(path)
        elif depth >= 2:
            paths.extend(flatten_schema(child, depth - 1, path))

    return paths
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value from a nested dictionary by following a sequence of keys.

    Args:
        d: the dictionary to search.
        *path: `(name, key)` tuples, one per nesting level. `key` is looked up in the
            current dictionary; `name` describes that level in the error message when
            `key` is missing.
        raise_on_missing: whether to raise when a key is missing instead of returning None.

    Returns:
        The value at the end of `path` (`d` itself if `path` is empty), or None if a key
        is missing and `raise_on_missing` is False. A key whose lookup yields None,
        including one explicitly mapped to None, counts as missing.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is True. A level named
            "this" is reported as "table" in the message.
    """
    current = d
    for label, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue

        if not raise_on_missing:
            return None

        # Levels named "this" are reported as "table" (presumably the table part of the path).
        level = "table" if label == "this" else label
        raise ValueError(f"Unknown {level}: {key}")

    return current

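Get a value from a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where key is the key to look up at each nesting level, and name is a string used in the error message if key isn't found.
  • raise_on_missing: whether to raise a ValueError when a key isn't found, instead of returning None.
Returns:

The value, or None if it doesn't exist and raise_on_missing is False.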

def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`. Missing intermediate levels are
            created as empty dicts; existing ones are reused.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The same dictionary `d`, updated in place (left unchanged if `keys` is empty).
    """
    if not keys:
        return d

    subd = d
    for key in keys[:-1]:
        # Reuse the existing sub-dictionary, or attach and descend into a new empty one.
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d

Set a value in a nested dictionary, in place.

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.