Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import SchemaError
  9from sqlglot.helper import dict_depth
 10from sqlglot.trie import TrieResult, in_trie, new_trie
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dataframe.sql.types import StructType
 14    from sqlglot.dialects.dialect import DialectType
 15
 16    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 17
 18
 19class Schema(abc.ABC):
 20    """Abstract base class for database schemas"""
 21
 22    dialect: DialectType
 23
 24    @abc.abstractmethod
 25    def add_table(
 26        self,
 27        table: exp.Table | str,
 28        column_mapping: t.Optional[ColumnMapping] = None,
 29        dialect: DialectType = None,
 30        normalize: t.Optional[bool] = None,
 31        match_depth: bool = True,
 32    ) -> None:
 33        """
 34        Register or update a table. Some implementing classes may require column information to also be provided.
 35        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 36
 37        Args:
 38            table: the `Table` expression instance or string representing the table.
 39            column_mapping: a column mapping that describes the structure of the table.
 40            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 41            normalize: whether to normalize identifiers according to the dialect of interest.
 42            match_depth: whether to enforce that the table must match the schema's depth or not.
 43        """
 44
 45    @abc.abstractmethod
 46    def column_names(
 47        self,
 48        table: exp.Table | str,
 49        only_visible: bool = False,
 50        dialect: DialectType = None,
 51        normalize: t.Optional[bool] = None,
 52    ) -> t.Sequence[str]:
 53        """
 54        Get the column names for a table.
 55
 56        Args:
 57            table: the `Table` expression instance.
 58            only_visible: whether to include invisible columns.
 59            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 60            normalize: whether to normalize identifiers according to the dialect of interest.
 61
 62        Returns:
 63            The sequence of column names.
 64        """
 65
 66    @abc.abstractmethod
 67    def get_column_type(
 68        self,
 69        table: exp.Table | str,
 70        column: exp.Column | str,
 71        dialect: DialectType = None,
 72        normalize: t.Optional[bool] = None,
 73    ) -> exp.DataType:
 74        """
 75        Get the `sqlglot.exp.DataType` type of a column in the schema.
 76
 77        Args:
 78            table: the source table.
 79            column: the target column.
 80            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 81            normalize: whether to normalize identifiers according to the dialect of interest.
 82
 83        Returns:
 84            The resulting column type.
 85        """
 86
 87    def has_column(
 88        self,
 89        table: exp.Table | str,
 90        column: exp.Column | str,
 91        dialect: DialectType = None,
 92        normalize: t.Optional[bool] = None,
 93    ) -> bool:
 94        """
 95        Returns whether `column` appears in `table`'s schema.
 96
 97        Args:
 98            table: the source table.
 99            column: the target column.
100            dialect: the SQL dialect that will be used to parse `table` if it's a string.
101            normalize: whether to normalize identifiers according to the dialect of interest.
102
103        Returns:
104            True if the column appears in the schema, False otherwise.
105        """
106        name = column if isinstance(column, str) else column.name
107        return name in self.column_names(table, dialect=dialect, normalize=normalize)
108
109    @property
110    @abc.abstractmethod
111    def supported_table_args(self) -> t.Tuple[str, ...]:
112        """
113        Table arguments this schema support, e.g. `("this", "db", "catalog")`
114        """
115
116    @property
117    def empty(self) -> bool:
118        """Returns whether the schema is empty."""
119        return True
120
121
class AbstractMappingSchema:
    """Schema helper backed by a nested mapping, with trie-accelerated table lookup."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}

        # The trie is keyed by reversed table paths (table, db, catalog), which makes
        # it possible to resolve partially-qualified table references.
        flattened = flatten_schema(self.mapping, depth=self.depth())
        self.mapping_trie = new_trie(tuple(reversed(path)) for path in flattened)

        # Lazily computed in `supported_table_args`.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Whether the underlying mapping contains no entries."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """The table path parts this mapping can resolve, based on its nesting depth."""
        if self.mapping and not self._supported_table_args:
            depth = self.depth()

            if not depth:
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty path parts of `table`, e.g. ["this", "db", "catalog"] texts."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]

        parts = []
        for part in exp.TABLE_PARTS:
            text = table.text(part)
            if text:
                parts.append(text)

        return parts

    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[: len(self.supported_table_args)]
        result, trie = in_trie(self.mapping_trie, parts)

        if result == TrieResult.FAILED:
            return None

        if result == TrieResult.PREFIX:
            # The path only partially matched, so it's ambiguous unless exactly one
            # completion exists in the trie.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) != 1:
                if raise_on_missing:
                    candidates = ", ".join(".".join(p) for p in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {table}: {candidates}.")
                return None

            parts.extend(possibilities[0])

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Look `parts` up in `d` (or this mapping), pairing each key with its arg name for errors."""
        target = d or self.mapping
        return nested_get(
            target,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
197
198
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Cache of type string -> exp.DataType, so repeated lookups don't re-parse.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed by `depth()`; 0 means "not computed yet" (or an empty schema).
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from an existing one, sharing its mapping state."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs: t.Any) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the copied constructor arguments.

        Note: `mapping.copy()` / `visible.copy()` are shallow, so nested dicts are shared
        with the original.
        """
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An already-registered table is only updated when new column info is provided.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # Keep the nested mapping and the lookup trie in sync.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names of `table`, optionally restricted to visible columns."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the `exp.DataType` of `column` in `table`, or the "unknown" type if absent."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            # Types may be stored pre-built or as raw strings; strings are parsed and cached.
            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return True if `column` exists in `table`'s schema, False otherwise."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # Flatten down to the table level (depth - 1 excludes the column mapping itself).
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # `nested_get` expects (name, key) pairs; here the key doubles as its own label.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize its identifier parts per the dialect."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when we're about to mutate the parsed table's identifiers below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its text, falling back to instance defaults."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Nesting depth of the schema, excluding the column mapping level."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                # NOTE(review): parse failures apparently surface as AttributeError here —
                # confirm against exp.DataType.build before narrowing/widening this clause.
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
450
451
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string, then normalize it per `dialect`'s rules."""
    parsed = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if normalize:
        # This is used for normalize_identifier; BigQuery has special rules pertaining to tables.
        parsed.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(parsed)

    return parsed
467
468
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Return `schema` as-is if it's already a `Schema`, otherwise wrap it in a `MappingSchema`."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
474
475
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce `mapping` into a `{column_name: column_type}` dict.

    Accepts None (empty mapping), a dict (returned as-is), a "name: type, ..." string,
    a Spark-style StructType (detected via its `simpleString` attribute), or a list of
    column names (types default to None). Anything else raises ValueError.
    """
    if mapping is None:
        return {}
    if isinstance(mapping, dict):
        return mapping
    if isinstance(mapping, str):
        # e.g. "a: int, b: text" -> {"a": "int", "b": "text"}
        columns = {}
        for pair in mapping.split(","):
            pieces = pair.strip().split(":")
            columns[pieces[0].strip()] = pieces[1].strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    if hasattr(mapping, "simpleString"):
        return {field.name: field.dataType.simpleString() for field in mapping}
    if isinstance(mapping, list):
        return {entry.strip(): None for entry in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
494
495
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    Flatten a nested schema mapping into a list of key paths of length `depth`.

    A depth of 1 yields the top-level keys; deeper levels recurse. A depth below 1
    yields no paths.
    """
    prefix = keys or []

    if depth == 1:
        return [prefix + [name] for name in schema]

    paths: t.List[t.List[str]] = []
    if depth >= 2:
        for name, subschema in schema.items():
            paths.extend(flatten_schema(subschema, depth - 1, prefix + [name]))

    return paths
509
510
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current: t.Optional[t.Any] = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the expression arg name for the table part; report it as "table".
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
535
536
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary.

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the key path at which `value` is stored.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parent_keys, final_key = keys

    node = d
    for key in parent_keys:
        # setdefault both creates missing intermediate dicts and descends into existing ones.
        node = node.setdefault(key, {})

    node[final_key] = value
    return d
class Schema(abc.ABC):
 20class Schema(abc.ABC):
 21    """Abstract base class for database schemas"""
 22
 23    dialect: DialectType
 24
 25    @abc.abstractmethod
 26    def add_table(
 27        self,
 28        table: exp.Table | str,
 29        column_mapping: t.Optional[ColumnMapping] = None,
 30        dialect: DialectType = None,
 31        normalize: t.Optional[bool] = None,
 32        match_depth: bool = True,
 33    ) -> None:
 34        """
 35        Register or update a table. Some implementing classes may require column information to also be provided.
 36        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
 37
 38        Args:
 39            table: the `Table` expression instance or string representing the table.
 40            column_mapping: a column mapping that describes the structure of the table.
 41            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 42            normalize: whether to normalize identifiers according to the dialect of interest.
 43            match_depth: whether to enforce that the table must match the schema's depth or not.
 44        """
 45
 46    @abc.abstractmethod
 47    def column_names(
 48        self,
 49        table: exp.Table | str,
 50        only_visible: bool = False,
 51        dialect: DialectType = None,
 52        normalize: t.Optional[bool] = None,
 53    ) -> t.Sequence[str]:
 54        """
 55        Get the column names for a table.
 56
 57        Args:
 58            table: the `Table` expression instance.
 59            only_visible: whether to include invisible columns.
 60            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 61            normalize: whether to normalize identifiers according to the dialect of interest.
 62
 63        Returns:
 64            The sequence of column names.
 65        """
 66
 67    @abc.abstractmethod
 68    def get_column_type(
 69        self,
 70        table: exp.Table | str,
 71        column: exp.Column | str,
 72        dialect: DialectType = None,
 73        normalize: t.Optional[bool] = None,
 74    ) -> exp.DataType:
 75        """
 76        Get the `sqlglot.exp.DataType` type of a column in the schema.
 77
 78        Args:
 79            table: the source table.
 80            column: the target column.
 81            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 82            normalize: whether to normalize identifiers according to the dialect of interest.
 83
 84        Returns:
 85            The resulting column type.
 86        """
 87
 88    def has_column(
 89        self,
 90        table: exp.Table | str,
 91        column: exp.Column | str,
 92        dialect: DialectType = None,
 93        normalize: t.Optional[bool] = None,
 94    ) -> bool:
 95        """
 96        Returns whether `column` appears in `table`'s schema.
 97
 98        Args:
 99            table: the source table.
100            column: the target column.
101            dialect: the SQL dialect that will be used to parse `table` if it's a string.
102            normalize: whether to normalize identifiers according to the dialect of interest.
103
104        Returns:
105            True if the column appears in the schema, False otherwise.
106        """
107        name = column if isinstance(column, str) else column.name
108        return name in self.column_names(table, dialect=dialect, normalize=normalize)
109
110    @property
111    @abc.abstractmethod
112    def supported_table_args(self) -> t.Tuple[str, ...]:
113        """
114        Table arguments this schema support, e.g. `("this", "db", "catalog")`
115        """
116
117    @property
118    def empty(self) -> bool:
119        """Returns whether the schema is empty."""
120        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
25    @abc.abstractmethod
26    def add_table(
27        self,
28        table: exp.Table | str,
29        column_mapping: t.Optional[ColumnMapping] = None,
30        dialect: DialectType = None,
31        normalize: t.Optional[bool] = None,
32        match_depth: bool = True,
33    ) -> None:
34        """
35        Register or update a table. Some implementing classes may require column information to also be provided.
36        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
37
38        Args:
39            table: the `Table` expression instance or string representing the table.
40            column_mapping: a column mapping that describes the structure of the table.
41            dialect: the SQL dialect that will be used to parse `table` if it's a string.
42            normalize: whether to normalize identifiers according to the dialect of interest.
43            match_depth: whether to enforce that the table must match the schema's depth or not.
44        """

Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> Sequence[str]:
46    @abc.abstractmethod
47    def column_names(
48        self,
49        table: exp.Table | str,
50        only_visible: bool = False,
51        dialect: DialectType = None,
52        normalize: t.Optional[bool] = None,
53    ) -> t.Sequence[str]:
54        """
55        Get the column names for a table.
56
57        Args:
58            table: the `Table` expression instance.
59            only_visible: whether to include invisible columns.
60            dialect: the SQL dialect that will be used to parse `table` if it's a string.
61            normalize: whether to normalize identifiers according to the dialect of interest.
62
63        Returns:
64            The sequence of column names.
65        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
67    @abc.abstractmethod
68    def get_column_type(
69        self,
70        table: exp.Table | str,
71        column: exp.Column | str,
72        dialect: DialectType = None,
73        normalize: t.Optional[bool] = None,
74    ) -> exp.DataType:
75        """
76        Get the `sqlglot.exp.DataType` type of a column in the schema.
77
78        Args:
79            table: the source table.
80            column: the target column.
81            dialect: the SQL dialect that will be used to parse `table` if it's a string.
82            normalize: whether to normalize identifiers according to the dialect of interest.
83
84        Returns:
85            The resulting column type.
86        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
 88    def has_column(
 89        self,
 90        table: exp.Table | str,
 91        column: exp.Column | str,
 92        dialect: DialectType = None,
 93        normalize: t.Optional[bool] = None,
 94    ) -> bool:
 95        """
 96        Returns whether `column` appears in `table`'s schema.
 97
 98        Args:
 99            table: the source table.
100            column: the target column.
101            dialect: the SQL dialect that will be used to parse `table` if it's a string.
102            normalize: whether to normalize identifiers according to the dialect of interest.
103
104        Returns:
105            True if the column appears in the schema, False otherwise.
106        """
107        name = column if isinstance(column, str) else column.name
108        return name in self.column_names(table, dialect=dialect, normalize=normalize)

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

supported_table_args: Tuple[str, ...]
110    @property
111    @abc.abstractmethod
112    def supported_table_args(self) -> t.Tuple[str, ...]:
113        """
114        Table arguments this schema support, e.g. `("this", "db", "catalog")`
115        """

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool
117    @property
118    def empty(self) -> bool:
119        """Returns whether the schema is empty."""
120        return True

Returns whether the schema is empty.

class AbstractMappingSchema:
123class AbstractMappingSchema:
124    def __init__(
125        self,
126        mapping: t.Optional[t.Dict] = None,
127    ) -> None:
128        self.mapping = mapping or {}
129        self.mapping_trie = new_trie(
130            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
131        )
132        self._supported_table_args: t.Tuple[str, ...] = tuple()
133
134    @property
135    def empty(self) -> bool:
136        return not self.mapping
137
138    def depth(self) -> int:
139        return dict_depth(self.mapping)
140
141    @property
142    def supported_table_args(self) -> t.Tuple[str, ...]:
143        if not self._supported_table_args and self.mapping:
144            depth = self.depth()
145
146            if not depth:  # None
147                self._supported_table_args = tuple()
148            elif 1 <= depth <= 3:
149                self._supported_table_args = exp.TABLE_PARTS[:depth]
150            else:
151                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
152
153        return self._supported_table_args
154
155    def table_parts(self, table: exp.Table) -> t.List[str]:
156        if isinstance(table.this, exp.ReadCSV):
157            return [table.this.name]
158        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
159
160    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
161        """
162        Returns the schema of a given table.
163
164        Args:
165            table: the target table.
166            raise_on_missing: whether to raise in case the schema is not found.
167
168        Returns:
169            The schema of the target table.
170        """
171        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
172        value, trie = in_trie(self.mapping_trie, parts)
173
174        if value == TrieResult.FAILED:
175            return None
176
177        if value == TrieResult.PREFIX:
178            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
179
180            if len(possibilities) == 1:
181                parts.extend(possibilities[0])
182            else:
183                message = ", ".join(".".join(parts) for parts in possibilities)
184                if raise_on_missing:
185                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
186                return None
187
188        return self.nested_get(parts, raise_on_missing=raise_on_missing)
189
190    def nested_get(
191        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
192    ) -> t.Optional[t.Any]:
193        return nested_get(
194            d or self.mapping,
195            *zip(self.supported_table_args, reversed(parts)),
196            raise_on_missing=raise_on_missing,
197        )
AbstractMappingSchema(mapping: Optional[Dict] = None)
124    def __init__(
125        self,
126        mapping: t.Optional[t.Dict] = None,
127    ) -> None:
128        self.mapping = mapping or {}
129        self.mapping_trie = new_trie(
130            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
131        )
132        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
134    @property
135    def empty(self) -> bool:
136        return not self.mapping
def depth(self) -> int:
138    def depth(self) -> int:
139        return dict_depth(self.mapping)
supported_table_args: Tuple[str, ...]
141    @property
142    def supported_table_args(self) -> t.Tuple[str, ...]:
143        if not self._supported_table_args and self.mapping:
144            depth = self.depth()
145
146            if not depth:  # None
147                self._supported_table_args = tuple()
148            elif 1 <= depth <= 3:
149                self._supported_table_args = exp.TABLE_PARTS[:depth]
150            else:
151                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
152
153        return self._supported_table_args
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
155    def table_parts(self, table: exp.Table) -> t.List[str]:
156        if isinstance(table.this, exp.ReadCSV):
157            return [table.this.name]
158        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, raise_on_missing: bool = True) -> Optional[Any]:
160    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
161        """
162        Returns the schema of a given table.
163
164        Args:
165            table: the target table.
166            raise_on_missing: whether to raise in case the schema is not found.
167
168        Returns:
169            The schema of the target table.
170        """
171        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
172        value, trie = in_trie(self.mapping_trie, parts)
173
174        if value == TrieResult.FAILED:
175            return None
176
177        if value == TrieResult.PREFIX:
178            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
179
180            if len(possibilities) == 1:
181                parts.extend(possibilities[0])
182            else:
183                message = ", ".join(".".join(parts) for parts in possibilities)
184                if raise_on_missing:
185                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
186                return None
187
188        return self.nested_get(parts, raise_on_missing=raise_on_missing)

Returns the schema of a given table.

Arguments:
  • table: the target table.
  • raise_on_missing: whether to raise in case the schema is not found.
Returns:

The schema of the target table.

def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
190    def nested_get(
191        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
192    ) -> t.Optional[t.Any]:
193        return nested_get(
194            d or self.mapping,
195            *zip(self.supported_table_args, reversed(parts)),
196            raise_on_missing=raise_on_missing,
197        )
class MappingSchema(AbstractMappingSchema, Schema):
200class MappingSchema(AbstractMappingSchema, Schema):
201    """
202    Schema based on a nested mapping.
203
204    Args:
205        schema: Mapping in one of the following forms:
206            1. {table: {col: type}}
207            2. {db: {table: {col: type}}}
208            3. {catalog: {db: {table: {col: type}}}}
209            4. None - Tables will be added later
210        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
211            are assumed to be visible. The nesting should mirror that of the schema:
212            1. {table: set(*cols)}}
213            2. {db: {table: set(*cols)}}}
214            3. {catalog: {db: {table: set(*cols)}}}}
215        dialect: The dialect to be used for custom type mappings & parsing string arguments.
216        normalize: Whether to normalize identifier names according to the given dialect or not.
217    """
218
219    def __init__(
220        self,
221        schema: t.Optional[t.Dict] = None,
222        visible: t.Optional[t.Dict] = None,
223        dialect: DialectType = None,
224        normalize: bool = True,
225    ) -> None:
226        self.dialect = dialect
227        self.visible = {} if visible is None else visible
228        self.normalize = normalize
229        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
230        self._depth = 0
231        schema = {} if schema is None else schema
232
233        super().__init__(self._normalize(schema) if self.normalize else schema)
234
235    @classmethod
236    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
237        return MappingSchema(
238            schema=mapping_schema.mapping,
239            visible=mapping_schema.visible,
240            dialect=mapping_schema.dialect,
241            normalize=mapping_schema.normalize,
242        )
243
244    def copy(self, **kwargs) -> MappingSchema:
245        return MappingSchema(
246            **{  # type: ignore
247                "schema": self.mapping.copy(),
248                "visible": self.visible.copy(),
249                "dialect": self.dialect,
250                "normalize": self.normalize,
251                **kwargs,
252            }
253        )
254
255    def add_table(
256        self,
257        table: exp.Table | str,
258        column_mapping: t.Optional[ColumnMapping] = None,
259        dialect: DialectType = None,
260        normalize: t.Optional[bool] = None,
261        match_depth: bool = True,
262    ) -> None:
263        """
264        Register or update a table. Updates are only performed if a new column mapping is provided.
265        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
266
267        Args:
268            table: the `Table` expression instance or string representing the table.
269            column_mapping: a column mapping that describes the structure of the table.
270            dialect: the SQL dialect that will be used to parse `table` if it's a string.
271            normalize: whether to normalize identifiers according to the dialect of interest.
272            match_depth: whether to enforce that the table must match the schema's depth or not.
273        """
274        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
275
276        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
277            raise SchemaError(
278                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
279                f"schema's nesting level: {self.depth()}."
280            )
281
282        normalized_column_mapping = {
283            self._normalize_name(key, dialect=dialect, normalize=normalize): value
284            for key, value in ensure_column_mapping(column_mapping).items()
285        }
286
287        schema = self.find(normalized_table, raise_on_missing=False)
288        if schema and not normalized_column_mapping:
289            return
290
291        parts = self.table_parts(normalized_table)
292
293        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
294        new_trie([parts], self.mapping_trie)
295
296    def column_names(
297        self,
298        table: exp.Table | str,
299        only_visible: bool = False,
300        dialect: DialectType = None,
301        normalize: t.Optional[bool] = None,
302    ) -> t.List[str]:
303        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
304
305        schema = self.find(normalized_table)
306        if schema is None:
307            return []
308
309        if not only_visible or not self.visible:
310            return list(schema)
311
312        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
313        return [col for col in schema if col in visible]
314
315    def get_column_type(
316        self,
317        table: exp.Table | str,
318        column: exp.Column | str,
319        dialect: DialectType = None,
320        normalize: t.Optional[bool] = None,
321    ) -> exp.DataType:
322        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
323
324        normalized_column_name = self._normalize_name(
325            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
326        )
327
328        table_schema = self.find(normalized_table, raise_on_missing=False)
329        if table_schema:
330            column_type = table_schema.get(normalized_column_name)
331
332            if isinstance(column_type, exp.DataType):
333                return column_type
334            elif isinstance(column_type, str):
335                return self._to_data_type(column_type, dialect=dialect)
336
337        return exp.DataType.build("unknown")
338
339    def has_column(
340        self,
341        table: exp.Table | str,
342        column: exp.Column | str,
343        dialect: DialectType = None,
344        normalize: t.Optional[bool] = None,
345    ) -> bool:
346        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
347
348        normalized_column_name = self._normalize_name(
349            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
350        )
351
352        table_schema = self.find(normalized_table, raise_on_missing=False)
353        return normalized_column_name in table_schema if table_schema else False
354
355    def _normalize(self, schema: t.Dict) -> t.Dict:
356        """
357        Normalizes all identifiers in the schema.
358
359        Args:
360            schema: the schema to normalize.
361
362        Returns:
363            The normalized schema mapping.
364        """
365        normalized_mapping: t.Dict = {}
366        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
367
368        for keys in flattened_schema:
369            columns = nested_get(schema, *zip(keys, keys))
370
371            if not isinstance(columns, dict):
372                raise SchemaError(
373                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
374                )
375
376            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
377            for column_name, column_type in columns.items():
378                nested_set(
379                    normalized_mapping,
380                    normalized_keys + [self._normalize_name(column_name)],
381                    column_type,
382                )
383
384        return normalized_mapping
385
386    def _normalize_table(
387        self,
388        table: exp.Table | str,
389        dialect: DialectType = None,
390        normalize: t.Optional[bool] = None,
391    ) -> exp.Table:
392        dialect = dialect or self.dialect
393        normalize = self.normalize if normalize is None else normalize
394
395        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)
396
397        if normalize:
398            for arg in exp.TABLE_PARTS:
399                value = normalized_table.args.get(arg)
400                if isinstance(value, exp.Identifier):
401                    normalized_table.set(
402                        arg,
403                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
404                    )
405
406        return normalized_table
407
408    def _normalize_name(
409        self,
410        name: str | exp.Identifier,
411        dialect: DialectType = None,
412        is_table: bool = False,
413        normalize: t.Optional[bool] = None,
414    ) -> str:
415        return normalize_name(
416            name,
417            dialect=dialect or self.dialect,
418            is_table=is_table,
419            normalize=self.normalize if normalize is None else normalize,
420        ).name
421
422    def depth(self) -> int:
423        if not self.empty and not self._depth:
424            # The columns themselves are a mapping, but we don't want to include those
425            self._depth = super().depth() - 1
426        return self._depth
427
428    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
429        """
430        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
431
432        Args:
433            schema_type: the type we want to convert.
434            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.
435
436        Returns:
437            The resulting expression type.
438        """
439        if schema_type not in self._type_mapping_cache:
440            dialect = dialect or self.dialect
441            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES
442
443            try:
444                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
445                self._type_mapping_cache[schema_type] = expression
446            except AttributeError:
447                in_dialect = f" in dialect {dialect}" if dialect else ""
448                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")
449
450        return self._type_mapping_cache[schema_type]

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}}
    2. {db: {table: set(*cols)}}}
    3. {catalog: {db: {table: set(*cols)}}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
219    def __init__(
220        self,
221        schema: t.Optional[t.Dict] = None,
222        visible: t.Optional[t.Dict] = None,
223        dialect: DialectType = None,
224        normalize: bool = True,
225    ) -> None:
226        self.dialect = dialect
227        self.visible = {} if visible is None else visible
228        self.normalize = normalize
229        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
230        self._depth = 0
231        schema = {} if schema is None else schema
232
233        super().__init__(self._normalize(schema) if self.normalize else schema)
dialect
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: MappingSchema) -> MappingSchema:
235    @classmethod
236    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
237        return MappingSchema(
238            schema=mapping_schema.mapping,
239            visible=mapping_schema.visible,
240            dialect=mapping_schema.dialect,
241            normalize=mapping_schema.normalize,
242        )
def copy(self, **kwargs) -> MappingSchema:
244    def copy(self, **kwargs) -> MappingSchema:
245        return MappingSchema(
246            **{  # type: ignore
247                "schema": self.mapping.copy(),
248                "visible": self.visible.copy(),
249                "dialect": self.dialect,
250                "normalize": self.normalize,
251                **kwargs,
252            }
253        )
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None, match_depth: bool = True) -> None:
255    def add_table(
256        self,
257        table: exp.Table | str,
258        column_mapping: t.Optional[ColumnMapping] = None,
259        dialect: DialectType = None,
260        normalize: t.Optional[bool] = None,
261        match_depth: bool = True,
262    ) -> None:
263        """
264        Register or update a table. Updates are only performed if a new column mapping is provided.
265        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
266
267        Args:
268            table: the `Table` expression instance or string representing the table.
269            column_mapping: a column mapping that describes the structure of the table.
270            dialect: the SQL dialect that will be used to parse `table` if it's a string.
271            normalize: whether to normalize identifiers according to the dialect of interest.
272            match_depth: whether to enforce that the table must match the schema's depth or not.
273        """
274        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
275
276        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
277            raise SchemaError(
278                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
279                f"schema's nesting level: {self.depth()}."
280            )
281
282        normalized_column_mapping = {
283            self._normalize_name(key, dialect=dialect, normalize=normalize): value
284            for key, value in ensure_column_mapping(column_mapping).items()
285        }
286
287        schema = self.find(normalized_table, raise_on_missing=False)
288        if schema and not normalized_column_mapping:
289            return
290
291        parts = self.table_parts(normalized_table)
292
293        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
294        new_trie([parts], self.mapping_trie)

Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
  • match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
296    def column_names(
297        self,
298        table: exp.Table | str,
299        only_visible: bool = False,
300        dialect: DialectType = None,
301        normalize: t.Optional[bool] = None,
302    ) -> t.List[str]:
303        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
304
305        schema = self.find(normalized_table)
306        if schema is None:
307            return []
308
309        if not only_visible or not self.visible:
310            return list(schema)
311
312        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
313        return [col for col in schema if col in visible]

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The sequence of column names.

def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
315    def get_column_type(
316        self,
317        table: exp.Table | str,
318        column: exp.Column | str,
319        dialect: DialectType = None,
320        normalize: t.Optional[bool] = None,
321    ) -> exp.DataType:
322        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
323
324        normalized_column_name = self._normalize_name(
325            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
326        )
327
328        table_schema = self.find(normalized_table, raise_on_missing=False)
329        if table_schema:
330            column_type = table_schema.get(normalized_column_name)
331
332            if isinstance(column_type, exp.DataType):
333                return column_type
334            elif isinstance(column_type, str):
335                return self._to_data_type(column_type, dialect=dialect)
336
337        return exp.DataType.build("unknown")

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

def has_column( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column | str, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> bool:
339    def has_column(
340        self,
341        table: exp.Table | str,
342        column: exp.Column | str,
343        dialect: DialectType = None,
344        normalize: t.Optional[bool] = None,
345    ) -> bool:
346        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
347
348        normalized_column_name = self._normalize_name(
349            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
350        )
351
352        table_schema = self.find(normalized_table, raise_on_missing=False)
353        return normalized_column_name in table_schema if table_schema else False

Returns whether column appears in table's schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

True if the column appears in the schema, False otherwise.

def depth(self) -> int:
422    def depth(self) -> int:
423        if not self.empty and not self._depth:
424            # The columns themselves are a mapping, but we don't want to include those
425            self._depth = super().depth() - 1
426        return self._depth
def normalize_name( identifier: str | sqlglot.expressions.Identifier, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, is_table: bool = False, normalize: Optional[bool] = True) -> sqlglot.expressions.Identifier:
453def normalize_name(
454    identifier: str | exp.Identifier,
455    dialect: DialectType = None,
456    is_table: bool = False,
457    normalize: t.Optional[bool] = True,
458) -> exp.Identifier:
459    if isinstance(identifier, str):
460        identifier = exp.parse_identifier(identifier, dialect=dialect)
461
462    if not normalize:
463        return identifier
464
465    # this is used for normalize_identifier, bigquery has special rules pertaining tables
466    identifier.meta["is_table"] = is_table
467    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
def ensure_schema( schema: Union[Schema, Dict, NoneType], **kwargs: Any) -> Schema:
470def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
471    if isinstance(schema, Schema):
472        return schema
473
474    return MappingSchema(schema, **kwargs)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
477def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
478    if mapping is None:
479        return {}
480    elif isinstance(mapping, dict):
481        return mapping
482    elif isinstance(mapping, str):
483        col_name_type_strs = [x.strip() for x in mapping.split(",")]
484        return {
485            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
486            for name_type_str in col_name_type_strs
487        }
488    # Check if mapping looks like a DataFrame StructType
489    elif hasattr(mapping, "simpleString"):
490        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
491    elif isinstance(mapping, list):
492        return {x.strip(): None for x in mapping}
493
494    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
497def flatten_schema(
498    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
499) -> t.List[t.List[str]]:
500    tables = []
501    keys = keys or []
502
503    for k, v in schema.items():
504        if depth >= 2:
505            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
506        elif depth == 1:
507            tables.append(keys + [k])
508
509    return tables
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
512def nested_get(
513    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
514) -> t.Optional[t.Any]:
515    """
516    Get a value for a nested dictionary.
517
518    Args:
519        d: the dictionary to search.
520        *path: tuples of (name, key), where:
521            `key` is the key in the dictionary to get.
522            `name` is a string to use in the error if `key` isn't found.
523
524    Returns:
525        The value or None if it doesn't exist.
526    """
527    for name, key in path:
528        d = d.get(key)  # type: ignore
529        if d is None:
530            if raise_on_missing:
531                name = "table" if name == "this" else name
532                raise ValueError(f"Unknown {name}: {key}")
533            return None
534
535    return d

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where: key is the key in the dictionary to get, and name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
538def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
539    """
540    In-place set a value for a nested dictionary
541
542    Example:
543        >>> nested_set({}, ["top_key", "second_key"], "value")
544        {'top_key': {'second_key': 'value'}}
545
546        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
547        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
548
549    Args:
550        d: dictionary to update.
551        keys: the keys that makeup the path to `value`.
552        value: the value to set in the dictionary for the given key path.
553
554    Returns:
555        The (possibly) updated dictionary.
556    """
557    if not keys:
558        return d
559
560    if len(keys) == 1:
561        d[keys[0]] = value
562        return d
563
564    subd = d
565    for key in keys[:-1]:
566        if key not in subd:
567            subd = subd.setdefault(key, {})
568        else:
569            subd = subd[key]
570
571    subd[keys[-1]] = value
572    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that makeup the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.