Edit on GitHub

sqlglot.optimizer.qualify_tables

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import DialectType
  8from sqlglot.helper import csv_reader, name_sequence
  9from sqlglot.optimizer.scope import Scope, traverse_scope
 10from sqlglot.schema import Schema
 11from sqlglot.dialects.dialect import Dialect
 12
 13if t.TYPE_CHECKING:
 14    from sqlglot._typing import E
 15
 16
 17def qualify_tables(
 18    expression: E,
 19    db: t.Optional[str | exp.Identifier] = None,
 20    catalog: t.Optional[str | exp.Identifier] = None,
 21    schema: t.Optional[Schema] = None,
 22    infer_csv_schemas: bool = False,
 23    dialect: DialectType = None,
 24) -> E:
 25    """
 26    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
 27    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
 28
 29    Examples:
 30        >>> import sqlglot
 31        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
 32        >>> qualify_tables(expression, db="db").sql()
 33        'SELECT 1 FROM db.tbl AS tbl'
 34        >>>
 35        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
 36        >>> qualify_tables(expression).sql()
 37        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
 38
 39    Args:
 40        expression: Expression to qualify
 41        db: Database name
 42        catalog: Catalog name
 43        schema: A schema to populate
 44        infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
 45        dialect: The dialect to parse catalog and schema into.
 46
 47    Returns:
 48        The qualified expression.
 49    """
 50    next_alias_name = name_sequence("_q_")
 51    db = exp.parse_identifier(db, dialect=dialect) if db else None
 52    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
 53    dialect = Dialect.get_or_raise(dialect)
 54
 55    def _qualify(table: exp.Table) -> None:
 56        if isinstance(table.this, exp.Identifier):
 57            if db and not table.args.get("db"):
 58                table.set("db", db.copy())
 59            if catalog and not table.args.get("catalog") and table.args.get("db"):
 60                table.set("catalog", catalog.copy())
 61
 62    if (db or catalog) and not isinstance(expression, exp.Query):
 63        with_ = expression.args.get("with") or exp.With()
 64        cte_names = {cte.alias_or_name for cte in with_.expressions}
 65
 66        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
 67            if isinstance(node, exp.Table) and node.name not in cte_names:
 68                _qualify(node)
 69
 70    for scope in traverse_scope(expression):
 71        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
 72            if isinstance(derived_table, exp.Subquery):
 73                unnested = derived_table.unnest()
 74                if isinstance(unnested, exp.Table):
 75                    joins = unnested.args.pop("joins", None)
 76                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
 77                    derived_table.this.set("joins", joins)
 78
 79            if not derived_table.args.get("alias"):
 80                alias_ = next_alias_name()
 81                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
 82                scope.rename_source(None, alias_)
 83
 84            pivots = derived_table.args.get("pivots")
 85            if pivots and not pivots[0].alias:
 86                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
 87
 88        table_aliases = {}
 89
 90        for name, source in scope.sources.items():
 91            if isinstance(source, exp.Table):
 92                pivots = source.args.get("pivots")
 93                if not source.alias:
 94                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
 95                    if pivots and pivots[0].alias == name:
 96                        name = source.name
 97
 98                    # Mutates the source by attaching an alias to it
 99                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)
100
101                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
102                    source.alias
103                )
104
105                if pivots:
106                    pivot = pivots[0]
107                    if not pivot.alias:
108                        pivot_alias = source.alias if pivot.unpivot else next_alias_name()
109                        pivot.set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias)))
110
111                    # This case corresponds to a pivoted CTE, we don't want to qualify that
112                    if isinstance(scope.sources.get(source.alias_or_name), Scope):
113                        continue
114
115                _qualify(source)
116
117                if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
118                    with csv_reader(source.this) as reader:
119                        header = next(reader)
120                        columns = next(reader)
121                        schema.add_table(
122                            source,
123                            {k: type(v).__name__ for k, v in zip(header, columns)},
124                            match_depth=False,
125                        )
126            elif isinstance(source, Scope) and source.is_udtf:
127                udtf = source.expression
128                table_alias = udtf.args.get("alias") or exp.TableAlias(
129                    this=exp.to_identifier(next_alias_name())
130                )
131                if (
132                    isinstance(udtf, exp.Unnest)
133                    and dialect.UNNEST_COLUMN_ONLY
134                    and not table_alias.columns
135                ):
136                    table_alias.set("columns", [table_alias.this.copy()])
137                    table_alias.set("column_only", True)
138
139                udtf.set("alias", table_alias)
140
141                if not table_alias.name:
142                    table_alias.set("this", exp.to_identifier(next_alias_name()))
143                if isinstance(udtf, exp.Values) and not table_alias.columns:
144                    column_aliases = dialect.generate_values_aliases(udtf)
145                    table_alias.set("columns", column_aliases)
146            else:
147                for node in scope.walk():
148                    if (
149                        isinstance(node, exp.Table)
150                        and not node.alias
151                        and isinstance(node.parent, (exp.From, exp.Join))
152                    ):
153                        # Mutates the table by attaching an alias to it
154                        alias(node, node.name, copy=False, table=True)
155
156        for column in scope.columns:
157            if column.db:
158                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))
159
160                if table_alias:
161                    for p in exp.COLUMN_PARTS[1:]:
162                        column.set(p, None)
163
164                    column.set("table", table_alias.copy())
165
166    return expression
def qualify_tables( expression: ~E, db: Union[sqlglot.expressions.Identifier, str, NoneType] = None, catalog: Union[sqlglot.expressions.Identifier, str, NoneType] = None, schema: Optional[sqlglot.schema.Schema] = None, infer_csv_schemas: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> ~E:
 18def qualify_tables(
 19    expression: E,
 20    db: t.Optional[str | exp.Identifier] = None,
 21    catalog: t.Optional[str | exp.Identifier] = None,
 22    schema: t.Optional[Schema] = None,
 23    infer_csv_schemas: bool = False,
 24    dialect: DialectType = None,
 25) -> E:
 26    """
 27    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
 28    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
 29
 30    Examples:
 31        >>> import sqlglot
 32        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
 33        >>> qualify_tables(expression, db="db").sql()
 34        'SELECT 1 FROM db.tbl AS tbl'
 35        >>>
 36        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
 37        >>> qualify_tables(expression).sql()
 38        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
 39
 40    Args:
 41        expression: Expression to qualify
 42        db: Database name
 43        catalog: Catalog name
 44        schema: A schema to populate
 45        infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
 46        dialect: The dialect to parse catalog and schema into.
 47
 48    Returns:
 49        The qualified expression.
 50    """
 51    next_alias_name = name_sequence("_q_")
 52    db = exp.parse_identifier(db, dialect=dialect) if db else None
 53    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
 54    dialect = Dialect.get_or_raise(dialect)
 55
 56    def _qualify(table: exp.Table) -> None:
 57        if isinstance(table.this, exp.Identifier):
 58            if db and not table.args.get("db"):
 59                table.set("db", db.copy())
 60            if catalog and not table.args.get("catalog") and table.args.get("db"):
 61                table.set("catalog", catalog.copy())
 62
 63    if (db or catalog) and not isinstance(expression, exp.Query):
 64        with_ = expression.args.get("with") or exp.With()
 65        cte_names = {cte.alias_or_name for cte in with_.expressions}
 66
 67        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
 68            if isinstance(node, exp.Table) and node.name not in cte_names:
 69                _qualify(node)
 70
 71    for scope in traverse_scope(expression):
 72        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
 73            if isinstance(derived_table, exp.Subquery):
 74                unnested = derived_table.unnest()
 75                if isinstance(unnested, exp.Table):
 76                    joins = unnested.args.pop("joins", None)
 77                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
 78                    derived_table.this.set("joins", joins)
 79
 80            if not derived_table.args.get("alias"):
 81                alias_ = next_alias_name()
 82                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
 83                scope.rename_source(None, alias_)
 84
 85            pivots = derived_table.args.get("pivots")
 86            if pivots and not pivots[0].alias:
 87                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
 88
 89        table_aliases = {}
 90
 91        for name, source in scope.sources.items():
 92            if isinstance(source, exp.Table):
 93                pivots = source.args.get("pivots")
 94                if not source.alias:
 95                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
 96                    if pivots and pivots[0].alias == name:
 97                        name = source.name
 98
 99                    # Mutates the source by attaching an alias to it
100                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)
101
102                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
103                    source.alias
104                )
105
106                if pivots:
107                    pivot = pivots[0]
108                    if not pivot.alias:
109                        pivot_alias = source.alias if pivot.unpivot else next_alias_name()
110                        pivot.set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias)))
111
112                    # This case corresponds to a pivoted CTE, we don't want to qualify that
113                    if isinstance(scope.sources.get(source.alias_or_name), Scope):
114                        continue
115
116                _qualify(source)
117
118                if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
119                    with csv_reader(source.this) as reader:
120                        header = next(reader)
121                        columns = next(reader)
122                        schema.add_table(
123                            source,
124                            {k: type(v).__name__ for k, v in zip(header, columns)},
125                            match_depth=False,
126                        )
127            elif isinstance(source, Scope) and source.is_udtf:
128                udtf = source.expression
129                table_alias = udtf.args.get("alias") or exp.TableAlias(
130                    this=exp.to_identifier(next_alias_name())
131                )
132                if (
133                    isinstance(udtf, exp.Unnest)
134                    and dialect.UNNEST_COLUMN_ONLY
135                    and not table_alias.columns
136                ):
137                    table_alias.set("columns", [table_alias.this.copy()])
138                    table_alias.set("column_only", True)
139
140                udtf.set("alias", table_alias)
141
142                if not table_alias.name:
143                    table_alias.set("this", exp.to_identifier(next_alias_name()))
144                if isinstance(udtf, exp.Values) and not table_alias.columns:
145                    column_aliases = dialect.generate_values_aliases(udtf)
146                    table_alias.set("columns", column_aliases)
147            else:
148                for node in scope.walk():
149                    if (
150                        isinstance(node, exp.Table)
151                        and not node.alias
152                        and isinstance(node.parent, (exp.From, exp.Join))
153                    ):
154                        # Mutates the table by attaching an alias to it
155                        alias(node, node.name, copy=False, table=True)
156
157        for column in scope.columns:
158            if column.db:
159                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))
160
161                if table_alias:
162                    for p in exp.COLUMN_PARTS[1:]:
163                        column.set(p, None)
164
165                    column.set("table", table_alias.copy())
166
167    return expression

Rewrite a sqlglot AST so that its tables are fully qualified. Join constructs such as (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
>>> qualify_tables(expression).sql()
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
Arguments:
  • expression: Expression to qualify
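  • db: Default database name, applied to plain named tables that have no database part.
  • catalog: Default catalog name, applied to plain named tables that have a database part (possibly just filled in from db) but no catalog part.
  • schema: A schema to populate with the columns of READ_CSV sources; only used when infer_csv_schemas is True.
  • infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
  • dialect: The dialect used to parse db and catalog into identifiers and to apply dialect-specific aliasing of UNNEST and VALUES sources.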
Returns:

The qualified expression.