Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth, first
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dialects.dialect import DialectType
 14
 15    ColumnMapping = t.Union[t.Dict, str, t.List]
 16
 17
 18class Schema(abc.ABC):
 19    """Abstract base class for database schemas"""
 20
 21    dialect: DialectType
 22
 23    @abc.abstractmethod
 24    def add_table(
 25        self,
 26        table: exp.Table | str,
 27        column_mapping: t.Optional[ColumnMapping] = None,
 28        dialect: DialectType = None,
 29        normalize: t.Optional[bool] = None,
 30        match_depth: bool = True,
 31    ) -> None:
 32        """
 33        Register or update a table. Some implementing classes may require column information to also be provided.
 34        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 35
 36        Args:
 37            table: the `Table` expression instance or string representing the table.
 38            column_mapping: a column mapping that describes the structure of the table.
 39            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 40            normalize: whether to normalize identifiers according to the dialect of interest.
 41            match_depth: whether to enforce that the table must match the schema's depth or not.
 42        """
 43
 44    @abc.abstractmethod
 45    def column_names(
 46        self,
 47        table: exp.Table | str,
 48        only_visible: bool = False,
 49        dialect: DialectType = None,
 50        normalize: t.Optional[bool] = None,
 51    ) -> t.Sequence[str]:
 52        """
 53        Get the column names for a table.
 54
 55        Args:
 56            table: the `Table` expression instance.
 57            only_visible: whether to include invisible columns.
 58            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 59            normalize: whether to normalize identifiers according to the dialect of interest.
 60
 61        Returns:
 62            The sequence of column names.
 63        """
 64
 65    @abc.abstractmethod
 66    def get_column_type(
 67        self,
 68        table: exp.Table | str,
 69        column: exp.Column | str,
 70        dialect: DialectType = None,
 71        normalize: t.Optional[bool] = None,
 72    ) -> exp.DataType:
 73        """
 74        Get the `sqlglot.exp.DataType` type of a column in the schema.
 75
 76        Args:
 77            table: the source table.
 78            column: the target column.
 79            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 80            normalize: whether to normalize identifiers according to the dialect of interest.
 81
 82        Returns:
 83            The resulting column type.
 84        """
 85
 86    def has_column(
 87        self,
 88        table: exp.Table | str,
 89        column: exp.Column | str,
 90        dialect: DialectType = None,
 91        normalize: t.Optional[bool] = None,
 92    ) -> bool:
 93        """
 94        Returns whether `column` appears in `table`'s schema.
 95
 96        Args:
 97            table: the source table.
 98            column: the target column.
 99            dialect: the SQL dialect that will be used to parse `table` if it's a string.
100            normalize: whether to normalize identifiers according to the dialect of interest.
101
102        Returns:
103            True if the column appears in the schema, False otherwise.
104        """
105        name = column if isinstance(column, str) else column.name
106        return name in self.column_names(table, dialect=dialect, normalize=normalize)
107
108    @property
109    @abc.abstractmethod
110    def supported_table_args(self) -> t.Tuple[str, ...]:
111        """
112        Table arguments this schema support, e.g. `("this", "db", "catalog")`
113        """
114
115    @property
116    def empty(self) -> bool:
117        """Returns whether the schema is empty."""
118        return True
119
120
class AbstractMappingSchema:
    """Base implementation for schemas backed by a nested `dict` mapping plus a lookup trie."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie is keyed table-first (reversed key paths), so lookups can be
        # performed with partially-qualified table names.
        schema_paths = flatten_schema(self.mapping, depth=self.depth())
        self.mapping_trie = new_trie(tuple(reversed(path)) for path in schema_paths)
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the underlying mapping holds no tables."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Computed lazily and cached; invalid shapes (deeper than catalog.db.table) raise.
        if self.mapping and not self._supported_table_args:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif depth > 3:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
            else:
                self._supported_table_args = exp.TABLE_PARTS[:depth]

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the table's non-empty name parts, ordered table-first (this, db, catalog)."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]

        parts = []
        for part in exp.TABLE_PARTS:
            text = table.text(part)
            if text:
                parts.append(text)

        return parts

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        arg_count = len(self.supported_table_args)
        parts = self.table_parts(table)[:arg_count]
        result, trie = in_trie(self.mapping_trie, parts)

        if result == TrieResult.FAILED:
            return None

        if result == TrieResult.PREFIX:
            # The table is under-qualified; it can only be resolved if the prefix
            # matches exactly one schema entry.
            possibilities = flatten_schema(trie)

            if len(possibilities) != 1:
                if raise_on_missing:
                    message = ", ".join(".".join(candidate) for candidate in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

            parts.extend(possibilities[0])

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        # The mapping is nested catalog -> db -> table, so the table-first parts
        # are reversed before descending.
        source = d or self.mapping
        path = zip(self.supported_table_args, reversed(parts))
        return nested_get(source, *path, raise_on_missing=raise_on_missing)
199
200
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Caches parsed `exp.DataType` objects per raw type string (see `_to_data_type`).
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed in `depth()`; 0 means "not computed yet".
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from an existing one, reusing its mapping and settings."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """Like `AbstractMappingSchema.find`, but can coerce string types to `exp.DataType`."""
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            # Replace string column types with their parsed `exp.DataType` equivalents.
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # Table already registered and no new column info supplied: nothing to update.
            return

        parts = self.table_parts(normalized_table)

        # The mapping is nested catalog -> db -> table (reversed parts), while the
        # lookup trie is keyed table-first.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the table's column names, optionally filtered by the `visible` mapping."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the column's `exp.DataType`, or the UNKNOWN type if it can't be resolved."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Stored as a raw string: parse (and cache) it on demand.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return whether `column` appears in `table`'s schema (after normalization)."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # Validate the shape: every leaf must be a non-empty {column: type} dict,
            # and all tables must sit at the same nesting level.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize its identifier parts in a copy."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Only copy when we're going to mutate the table's identifiers below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its bare name, using this schema's defaults."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
473
474
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and optionally normalize it per `dialect`'s rules."""
    parsed = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if normalize:
        # this is used for normalize_identifier, bigquery has special rules pertaining tables
        parsed.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(parsed)

    return parsed
490
491
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a `Schema` instance, wrapping dicts (or None) in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
497
498
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping into dictionary form.

    Accepts None (yields an empty mapping), a ready-made dict, a comma-separated
    "name: type, ..." string, or a list of column names (whose types default to None).

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        # Each comma-separated entry is a "name: type" pair.
        entries = (entry.split(":") for entry in mapping.split(","))
        return {parts[0].strip(): parts[1].strip() for parts in entries}
    if isinstance(mapping, list):
        return {column.strip(): None for column in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
514
515
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Flatten a nested schema mapping into a list of key paths, one per table.

    Args:
        schema: the nested mapping to flatten.
        depth: how many levels to descend; inferred from the mapping if not given.
        keys: the key-path prefix accumulated by recursive calls.

    Returns:
        A list of key paths, each a list of keys from the outermost level down to a table.
    """
    prefix = keys or []
    remaining = dict_depth(schema) - 1 if depth is None else depth

    paths: t.List[t.List[str]] = []
    for key, value in schema.items():
        if remaining == 1 or not isinstance(value, dict):
            paths.append(prefix + [key])
        elif remaining >= 2:
            paths.extend(flatten_schema(value, remaining - 1, prefix + [key]))

    return paths
530
531
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current: t.Optional[t.Any] = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the internal arg name for the table part; report it as "table".
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
556
557
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk (creating as needed) every level except the last, then assign the leaf.
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d
class Schema(abc.ABC):
 19class Schema(abc.ABC):
 20    """Abstract base class for database schemas"""
 21
 22    dialect: DialectType
 23
 24    @abc.abstractmethod
 25    def add_table(
 26        self,
 27        table: exp.Table | str,
 28        column_mapping: t.Optional[ColumnMapping] = None,
 29        dialect: DialectType = None,
 30        normalize: t.Optional[bool] = None,
 31        match_depth: bool = True,
 32    ) -> None:
 33        """
 34        Register or update a table. Some implementing classes may require column information to also be provided.
 35        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 36
 37        Args:
 38            table: the `Table` expression instance or string representing the table.
 39            column_mapping: a column mapping that describes the structure of the table.
 40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 41            normalize: whether to normalize identifiers according to the dialect of interest.
 42            match_depth: whether to enforce that the table must match the schema's depth or not.
 43        """
 44
 45    @abc.abstractmethod
 46    def column_names(
 47        self,
 48        table: exp.Table | str,
 49        only_visible: bool = False,
 50        dialect: DialectType = None,
 51        normalize: t.Optional[bool] = None,
 52    ) -> t.Sequence[str]:
 53        """
 54        Get the column names for a table.
 55
 56        Args:
 57            table: the `Table` expression instance.
 58            only_visible: whether to include invisible columns.
 59            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 60            normalize: whether to normalize identifiers according to the dialect of interest.
 61
 62        Returns:
 63            The sequence of column names.
 64        """
 65
 66    @abc.abstractmethod
 67    def get_column_type(
 68        self,
 69        table: exp.Table | str,
 70        column: exp.Column | str,
 71        dialect: DialectType = None,
 72        normalize: t.Optional[bool] = None,
 73    ) -> exp.DataType:
 74        """
 75        Get the `sqlglot.exp.DataType` type of a column in the schema.
 76
 77        Args:
 78            table: the source table.
 79            column: the target column.
 80            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 81            normalize: whether to normalize identifiers according to the dialect of interest.
 82
 83        Returns:
 84            The resulting column type.
 85        """
 86
 87    def has_column(
 88        self,
 89        table: exp.Table | str,
 90        column: exp.Column | str,
 91        dialect: DialectType = None,
 92        normalize: t.Optional[bool] = None,
 93    ) -> bool:
 94        """
 95        Returns whether `column` appears in `table`'s schema.
 96
 97        Args:
 98            table: the source table.
 99            column: the target column.
100            dialect: the SQL dialect that will be used to parse `table` if it's a string.
101            normalize: whether to normalize identifiers according to the dialect of interest.
102
103        Returns:
104            True if the column appears in the schema, False otherwise.
105        """
106        name = column if isinstance(column, str) else column.name
107        return name in self.column_names(table, dialect=dialect, normalize=normalize)
108
109    @property
110    @abc.abstractmethod
111    def supported_table_args(self) -> t.Tuple[str, ...]:
112        """
113        Table arguments this schema support, e.g. `("this", "db", "catalog")`
114        """
115
116    @property
117    def empty(self) -> bool:
118        """Returns whether the schema is empty."""
119        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
24    @abc.abstractmethod
25    def add_table(
26        self,
27        table: exp.Table | str,
28        column_mapping: t.Optional[ColumnMapping] = None,
29        dialect: DialectType = None,
30        normalize: t.Optional[bool] = None,
31        match_depth: bool = True,
32    ) -> None:
33        """
34        Register or update a table. Some implementing classes may require column information to also be provided.
35        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
36
37        Args:
38            table: the `Table` expression instance or string representing the table.
39            column_mapping: a column mapping that describes the structure of the table.
40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
41            normalize: whether to normalize identifiers according to the dialect of interest.
42            match_depth: whether to enforce that the table must match the schema's depth or not.
43        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> Sequence[str]:
45    @abc.abstractmethod
46    def column_names(
47        self,
48        table: exp.Table | str,
49        only_visible: bool = False,
50        dialect: DialectType = None,
51        normalize: t.Optional[bool] = None,
52    ) -> t.Sequence[str]:
53        """
54        Get the column names for a table.
55
56        Args:
57            table: the `Table` expression instance.
58            only_visible: whether to include invisible columns.
59            dialect: the SQL dialect that will be used to parse `table` if it's a string.
60            normalize: whether to normalize identifiers according to the dialect of interest.
61
62        Returns:
63            The sequence of column names.
64        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
66    @abc.abstractmethod
67    def get_column_type(
68        self,
69        table: exp.Table | str,
70        column: exp.Column | str,
71        dialect: DialectType = None,
72        normalize: t.Optional[bool] = None,
73    ) -> exp.DataType:
74        """
75        Get the `sqlglot.exp.DataType` type of a column in the schema.
76
77        Args:
78            table: the source table.
79            column: the target column.
80            dialect: the SQL dialect that will be used to parse `table` if it's a string.
81            normalize: whether to normalize identifiers according to the dialect of interest.
82
83        Returns:
84            The resulting column type.
85        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 87    def has_column(
 88        self,
 89        table: exp.Table | str,
 90        column: exp.Column | str,
 91        dialect: DialectType = None,
 92        normalize: t.Optional[bool] = None,
 93    ) -> bool:
 94        """
 95        Returns whether `column` appears in `table`'s schema.
 96
 97        Args:
 98            table: the source table.
 99            column: the target column.
100            dialect: the SQL dialect that will be used to parse `table` if it's a string.
101            normalize: whether to normalize identifiers according to the dialect of interest.
102
103        Returns:
104            True if the column appears in the schema, False otherwise.
105        """
106        name = column if isinstance(column, str) else column.name
107        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

supported_table_args: Tuple[str, ...]
109    @property
110    @abc.abstractmethod
111    def supported_table_args(self) -> t.Tuple[str, ...]:
112        """
113        Table arguments this schema support, e.g. `("this", "db", "catalog")`
114        """

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool
116    @property
117    def empty(self) -> bool:
118        """Returns whether the schema is empty."""
119        return True

Returns whether the schema is empty.

class AbstractMappingSchema:
    """Base class for schemas backed by a nested ``dict`` mapping.

    A trie is built over the *reversed* table paths so that partially
    qualified table names can be resolved from the table name outwards
    (table -> db -> catalog).
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Reverse each flattened path so lookups start at the table name.
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self.depth())
        )
        # Lazily computed cache for `supported_table_args`.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the underlying mapping contains no tables."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """The table path parts this schema supports, e.g. ``("this", "db", "catalog")``."""
        if self._supported_table_args or not self.mapping:
            return self._supported_table_args

        depth = self.depth()

        if not depth:  # an empty mapping has depth 0
            self._supported_table_args = tuple()
        elif 1 <= depth <= 3:
            self._supported_table_args = exp.TABLE_PARTS[:depth]
        else:
            raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the textual path parts of `table`, ordered table-first."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        result, trie = in_trie(self.mapping_trie, parts)

        if result == TrieResult.FAILED:
            return None

        if result == TrieResult.PREFIX:
            # The path only partially qualifies a table; it is usable only if
            # it disambiguates to exactly one candidate.
            candidates = flatten_schema(trie)

            if len(candidates) != 1:
                message = ", ".join(".".join(candidate) for candidate in candidates)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

            parts.extend(candidates[0])

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Look up `parts` (ordered table-first) in `d`, or in the schema mapping by default."""
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
AbstractMappingSchema(mapping: Optional[Dict] = None)
123    def __init__(
124        self,
125        mapping: t.Optional[t.Dict] = None,
126    ) -> None:
127        self.mapping = mapping or {}
128        self.mapping_trie = new_trie(
129            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
130        )
131        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
133    @property
134    def empty(self) -> bool:
135        return not self.mapping
def depth(self) -> int:
137    def depth(self) -> int:
138        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
140    @property
141    def supported_table_args(self) -> t.Tuple[str, ...]:
142        if not self._supported_table_args and self.mapping:
143            depth = self.depth()
144
145            if not depth:  # None
146                self._supported_table_args = tuple()
147            elif 1 <= depth <= 3:
148                self._supported_table_args = exp.TABLE_PARTS[:depth]
149            else:
150                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
151
152        return self._supported_table_args
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
154    def table_parts(self, table: exp.Table) -> t.List[str]:
155        if isinstance(table.this, exp.ReadCSV):
156            return [table.this.name]
157        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
159    def find(
160        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
161    ) -> t.Optional[t.Any]:
162        """
163        Returns the schema of a given table.
164
165        Args:
166            table: the target table.
167            raise_on_missing: whether to raise in case the schema is not found.
168            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
169
170        Returns:
171            The schema of the target table.
172        """
173        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
174        value, trie = in_trie(self.mapping_trie, parts)
175
176        if value == TrieResult.FAILED:
177            return None
178
179        if value == TrieResult.PREFIX:
180            possibilities = flatten_schema(trie)
181
182            if len(possibilities) == 1:
183                parts.extend(possibilities[0])
184            else:
185                message = ", ".join(".".join(parts) for parts in possibilities)
186                if raise_on_missing:
187                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
188                return None
189
190        return self.nested_get(parts, raise_on_missing=raise_on_missing)

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
  • ensure_data_types: whether to convert str types to their DataType equivalents.
Returns:

The schema of the target table.

def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
192    def nested_get(
193        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
194    ) -> t.Optional[t.Any]:
195        return nested_get(
196            d or self.mapping,
197            *zip(self.supported_table_args, reversed(parts)),
198            raise_on_missing=raise_on_missing,
199        )
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Caches string type -> exp.DataType conversions, since building types is not free.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed; see `depth`.
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` mirroring an existing instance."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An existing table is only updated when a new column mapping is supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type, or an "unknown" type if it cannot be determined.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            # zip(keys, keys) yields (name, key) pairs as `nested_get` expects.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its path parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its textual name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Nesting depth of the schema, excluding the column -> type mapping level."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
221    def __init__(
222        self,
223        schema: t.Optional[t.Dict] = None,
224        visible: t.Optional[t.Dict] = None,
225        dialect: DialectType = None,
226        normalize: bool = True,
227    ) -> None:
228        self.dialect = dialect
229        self.visible = {} if visible is None else visible
230        self.normalize = normalize
231        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
232        self._depth = 0
233        schema = {} if schema is None else schema
234
235        super().__init__(self._normalize(schema) if self.normalize else schema)
dialect
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
237    @classmethod
238    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
239        return MappingSchema(
240            schema=mapping_schema.mapping,
241            visible=mapping_schema.visible,
242            dialect=mapping_schema.dialect,
243            normalize=mapping_schema.normalize,
244        )
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True, ensure_data_types: bool = False) -> Optional[Any]:
246    def find(
247        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
248    ) -> t.Optional[t.Any]:
249        schema = super().find(
250            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
251        )
252        if ensure_data_types and isinstance(schema, dict):
253            schema = {
254                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
255                for col, dtype in schema.items()
256            }
257
258        return schema

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
  • ensure_data_types: whether to convert str types to their DataType equivalents.
Returns:

The schema of the target table.

def copy(self, **kwargs) -> MappingSchema:
260    def copy(self, **kwargs) -> MappingSchema:
261        return MappingSchema(
262            **{  # type: ignore
263                "schema": self.mapping.copy(),
264                "visible": self.visible.copy(),
265                "dialect": self.dialect,
266                "normalize": self.normalize,
267                **kwargs,
268            }
269        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
271    def add_table(
272        self,
273        table: exp.Table | str,
274        column_mapping: t.Optional[ColumnMapping] = None,
275        dialect: DialectType = None,
276        normalize: t.Optional[bool] = None,
277        match_depth: bool = True,
278    ) -> None:
279        """
280        Register or update a table. Updates are only performed if a new column mapping is provided.
281        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
282
283        Args:
284            table: the `Table` expression instance or string representing the table.
285            column_mapping: a column mapping that describes the structure of the table.
286            dialect: the SQL dialect that will be used to parse `table` if it's a string.
287            normalize: whether to normalize identifiers according to the dialect of interest.
288            match_depth: whether to enforce that the table must match the schema's depth or not.
289        """
290        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
291
292        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
293            raise SchemaError(
294                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
295                f"schema's nesting level: {self.depth()}."
296            )
297
298        normalized_column_mapping = {
299            self._normalize_name(key, dialect=dialect, normalize=normalize): value
300            for key, value in ensure_column_mapping(column_mapping).items()
301        }
302
303        schema = self.find(normalized_table, raise_on_missing=False)
304        if schema and not normalized_column_mapping:
305            return
306
307        parts = self.table_parts(normalized_table)
308
309        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
310        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
312    def column_names(
313        self,
314        table: exp.Table | str,
315        only_visible: bool = False,
316        dialect: DialectType = None,
317        normalize: t.Optional[bool] = None,
318    ) -> t.List[str]:
319        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
320
321        schema = self.find(normalized_table)
322        if schema is None:
323            return []
324
325        if not only_visible or not self.visible:
326            return list(schema)
327
328        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
329        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
331    def get_column_type(
332        self,
333        table: exp.Table | str,
334        column: exp.Column | str,
335        dialect: DialectType = None,
336        normalize: t.Optional[bool] = None,
337    ) -> exp.DataType:
338        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
339
340        normalized_column_name = self._normalize_name(
341            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
342        )
343
344        table_schema = self.find(normalized_table, raise_on_missing=False)
345        if table_schema:
346            column_type = table_schema.get(normalized_column_name)
347
348            if isinstance(column_type, exp.DataType):
349                return column_type
350            elif isinstance(column_type, str):
351                return self._to_data_type(column_type, dialect=dialect)
352
353        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
355    def has_column(
356        self,
357        table: exp.Table | str,
358        column: exp.Column | str,
359        dialect: DialectType = None,
360        normalize: t.Optional[bool] = None,
361    ) -> bool:
362        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
363
364        normalized_column_name = self._normalize_name(
365            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
366        )
367
368        table_schema = self.find(normalized_table, raise_on_missing=False)
369        return normalized_column_name in table_schema if table_schema else False

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
445    def depth(self) -> int:
446        if not self.empty and not self._depth:
447            # The columns themselves are a mapping, but we don't want to include those
448            self._depth = super().depth() - 1
449        return self._depth
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and optionally normalize it for `dialect`."""
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return ident

    # Used by normalize_identifier; BigQuery has special rules pertaining to tables.
    ident.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(ident)
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a `Schema` instance, wrapping mappings in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
def ensure_column_mapping(mapping: Union[Dict, str, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Normalize the supported column-mapping formats into a plain dict.

    Args:
        mapping: one of
            - None: no column information; yields {}.
            - dict: returned as-is.
            - str: comma-separated "name: type" pairs, e.g. "a: int, b: text".
            - list: column names only; their types default to None.

    Returns:
        A dict mapping column name to column type (None when the type is unknown).

    Raises:
        ValueError: if `mapping` has an unsupported type, or a string entry
            is missing the ":" separator.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns: t.Dict[str, str] = {}
        for name_type_str in mapping.split(","):
            # Split on the first ":" only, so types that themselves contain
            # colons (e.g. "struct<x:int>") are preserved intact. Unpacking
            # raises ValueError (not IndexError) for a malformed entry.
            name, col_type = name_type_str.split(":", 1)
            columns[name.strip()] = col_type.strip()
        return columns
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: Optional[int] = None, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested schema mapping into a list of table paths.

    Args:
        schema: the nested schema dict (catalog -> db -> table -> columns).
        depth: how many levels of qualifiers remain; inferred from the
            schema's dict depth (minus the column level) when not given.
        keys: path prefix accumulated by the recursion.

    Returns:
        A list of qualifier paths, one per table.
    """
    if depth is None:
        depth = dict_depth(schema) - 1

    prefix = keys or []
    tables: t.List[t.List[str]] = []

    for name, value in schema.items():
        if depth == 1 or not isinstance(value, dict):
            tables.append(prefix + [name])
        elif depth >= 2:
            tables += flatten_schema(value, depth - 1, prefix + [name])

    return tables
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """Walk a nested dictionary along `path` and return the value found.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where `key` is looked up at each level
            and `name` labels that level in the error message.
        raise_on_missing: raise ValueError instead of returning None when a
            key is absent.

    Returns:
        The nested value, or None when it's missing and `raise_on_missing`
        is False.

    Raises:
        ValueError: when a key is absent and `raise_on_missing` is True.
    """
    node: t.Optional[t.Any] = d
    for name, key in path:
        node = node.get(key)  # type: ignore
        if node is not None:
            continue
        if not raise_on_missing:
            return None
        # "this" is the AST slot holding a table's identifier; report it
        # as "table" for a friendlier error message.
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return node

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where `key` is the key to look up in the dictionary at each level and `name` is a label used in the error message if `key` isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, leaf = keys
    node = d
    for key in parents:
        # setdefault both creates a missing intermediate dict and descends
        # into the existing one, in a single lookup.
        node = node.setdefault(key, {})

    node[leaf] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that make up the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.