Edit on GitHub

sqlglot.optimizer.qualify_tables

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import exp
  6from sqlglot.dialects.dialect import Dialect, DialectType
  7from sqlglot.helper import name_sequence, seq_get, ensure_list
  8from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
  9from sqlglot.optimizer.scope import Scope, traverse_scope
 10
 11if t.TYPE_CHECKING:
 12    from sqlglot._typing import E
 13    from collections.abc import Sequence
 14
 15
 16def qualify_tables(
 17    expression: E,
 18    db: str | exp.Identifier | None = None,
 19    catalog: str | exp.Identifier | None = None,
 20    on_qualify: t.Callable[[exp.Table], None] | None = None,
 21    dialect: DialectType = None,
 22    canonicalize_table_aliases: bool = False,
 23) -> E:
 24    """
 25    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
 26    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
 27
       Sources that have no alias are also given one, and columns written with a
       db/catalog-qualified table name are rewritten to reference that table's alias.

 28    Examples:
 29        >>> import sqlglot
 30        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
 31        >>> qualify_tables(expression, db="db").sql()
 32        'SELECT 1 FROM db.tbl AS tbl'
 33        >>>
 34        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
 35        >>> qualify_tables(expression).sql()
 36        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
 37
 38    Args:
 39        expression: Expr to qualify
 40        db: Database name
 41        catalog: Catalog name
 42        on_qualify: Callback after a table has been qualified.
 43        dialect: The dialect used to parse and normalize the db and catalog identifiers.
 44        canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources
 45            instead of preserving table names. Defaults to False.
 46
 47    Returns:
 48        The qualified expression (the object that was passed in, modified in place).
 49    """
       # Resolve the `dialect` argument into a dialect object; its attributes are used below.
 50    dialect = Dialect.get_or_raise(dialect)
       # Supplies generated alias names (the docstring describes them as _0, _1, ...).
 51    next_alias_name = name_sequence("_")
 52
       # Falsy db/catalog values (e.g. "") are treated as missing. Otherwise each is parsed
       # into an identifier, tagged via meta["is_table"], and normalized for the dialect.
 53    if db := db or None:
 54        db = exp.parse_identifier(db, dialect=dialect)
 55        db.meta["is_table"] = True
 56        db = normalize_identifiers(db, dialect=dialect)
 57    if catalog := catalog or None:
 58        catalog = exp.parse_identifier(catalog, dialect=dialect)
 59        catalog.meta["is_table"] = True
 60        catalog = normalize_identifiers(catalog, dialect=dialect)
 61
 62    def _qualify(table: exp.Table) -> None:
           # Fill in the default db/catalog on `table` in place. Only tables whose `this`
           # is a plain Identifier are touched, and existing db/catalog parts are kept.
 63        if isinstance(table.this, exp.Identifier):
 64            if db and not table.args.get("db"):
 65                table.set("db", db.copy())
               # The catalog is only added when the table has a db (possibly the one just set).
 66            if catalog and not table.args.get("catalog") and table.args.get("db"):
 67                table.set("catalog", catalog.copy())
 68
       # Root is not an exp.Query: qualify the tables reachable from it directly. The walk
       # is pruned at exp.Query nodes (presumably those are covered by the scope traversal
       # below), and names matching CTEs in the root's WITH clause are skipped.
 69    if (db or catalog) and not isinstance(expression, exp.Query):
 70        with_ = expression.args.get("with_") or exp.With()
 71        cte_names = {cte.alias_or_name for cte in with_.expressions}
 72
 73        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
 74            if isinstance(node, exp.Table) and node.name not in cte_names:
 75                _qualify(node)
 76
 77    def _set_alias(
 78        expression: exp.Expr,
 79        canonical_aliases: dict[str, str],
 80        target_alias: str | None = None,
 81        scope: Scope | None = None,
 82        normalize: bool = False,
 83        columns: Sequence[str | exp.Identifier] | None = None,
 84    ) -> None:
           # Give `expression` a table alias when one is needed.
           #   - canonicalize_table_aliases: always assign the next generated name and record
           #     the mapping from the previous name (alias name, else target_alias, else "")
           #     in `canonical_aliases`.
           #   - no alias name yet: use `target_alias`, falling back to a generated name; the
           #     name is normalized only when `normalize` is set and `target_alias` was given.
           #   - otherwise the existing alias is kept and nothing else in here runs.
           # `columns`, when given, become the alias' column names. `scope`, when given, is
           # notified through rename_source(None, <new alias>).
 85        alias = expression.args.get("alias") or exp.TableAlias()
 86
 87        if canonicalize_table_aliases:
 88            new_alias_name = next_alias_name()
 89            canonical_aliases[alias.name or target_alias or ""] = new_alias_name
 90        elif not alias.name:
 91            new_alias_name = target_alias or next_alias_name()
 92            if normalize and target_alias:
 93                new_alias_name = normalize_identifiers(new_alias_name, dialect=dialect).name
 94        else:
 95            return
 96
 97        alias.set("this", exp.to_identifier(new_alias_name))
 98
 99        if columns:
100            alias.set("columns", [exp.to_identifier(c) for c in columns])
101
102        expression.set("alias", alias)
103
104        if scope:
               # presumably re-keys the scope's unnamed (None) source under the new alias —
               # confirm against Scope.rename_source
105            scope.rename_source(None, new_alias_name)
106
107    for scope in traverse_scope(expression):
           # Read up front, before any alias in this scope is rewritten below.
108        local_columns = scope.local_columns
           # Per-scope mapping of previous source name -> canonical alias (see _set_alias).
109        canonical_aliases: dict[str, str] = {}
110
111        for query in scope.subqueries:
112            subquery = query.parent
113            if isinstance(subquery, exp.Subquery):
                   # presumably collapses redundant wrapping around the subquery — confirm
                   # against Subquery.unwrap
114                subquery.unwrap().replace(subquery)
115
116        for derived_table in scope.derived_tables:
117            unnested = derived_table.unnest()
118            if isinstance(unnested, exp.Table):
                   # The derived table wraps a bare table (e.g. `(t1 JOIN t2) AS t`): swap its
                   # content for `SELECT * FROM <table>` and move the joins to that select.
119                joins = unnested.args.get("joins")
120                unnested.set("joins", None)
121                derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
122                derived_table.this.set("joins", joins)
123
124            _set_alias(derived_table, canonical_aliases, scope=scope)
               # Only the first pivot of a derived table is aliased.
125            if pivot := seq_get(derived_table.args.get("pivots") or [], 0):
126                _set_alias(pivot, canonical_aliases)
127
           # Maps a table's dotted name as written in the query (built before _qualify runs
           # on it) to its alias identifier; used to rewrite qualified columns further down.
128        table_aliases = {}
129
130        for name, source in scope.sources.items():
131            if isinstance(source, exp.Table):
132                # When the name is empty, it means that we have a non-table source, e.g. a pivoted cte
133                is_real_table_source = bool(name)
134
                   # For a pivoted table, alias it after the table's own name.
135                if pivot := seq_get(source.args.get("pivots") or [], 0):
136                    name = source.name
137
138                table_this = source.this
                   # Captured before _set_alias may attach or replace the alias.
139                table_alias = source.args.get("alias")
140                function_columns: Sequence[str | exp.Identifier] | None = None
141                if isinstance(table_this, exp.Func):
                       # Function used as a table source. Pick the alias column names:
                       #   1. no alias: the dialect's default column names for this function;
                       #   2. alias with columns: keep those columns;
                       #   3. alias without columns, function has dialect defaults: reuse the
                       #      alias name as the column name and drop it as the table alias.
142                    if not table_alias:
143                        function_columns = ensure_list(
144                            dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES.get(type(table_this))
145                        )
146                    elif columns := table_alias.columns:
147                        function_columns = columns
148                    elif type(table_this) in dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES:
149                        function_columns = ensure_list(source.alias_or_name)
150                        source.set("alias", None)
151                        name = ""
152
153                _set_alias(
154                    source,
155                    canonical_aliases,
156                    target_alias=name or source.name or None,
157                    normalize=True,
158                    columns=function_columns,
159                )
160
161                source_fqn = ".".join(p.name for p in source.parts)
162                had_explicit_alias = table_alias and table_alias.name
                   # A table without an explicit alias always takes the mapping for its name;
                   # an explicitly aliased one only fills it in when the name is still unmapped.
163                if not had_explicit_alias or source_fqn not in table_aliases:
164                    table_aliases[source_fqn] = source.args["alias"].this.copy()
165
166                if pivot:
                       # An UNPIVOT is aliased after the source's alias; a PIVOT gets a
                       # generated name (or keeps an alias it already has).
167                    target_alias = source.alias if pivot.unpivot else None
168                    _set_alias(pivot, canonical_aliases, target_alias=target_alias, normalize=True)
169
170                    # This case corresponds to a pivoted CTE, we don't want to qualify that
171                    if isinstance(scope.sources.get(source.alias_or_name), Scope):
172                        continue
173
174                if is_real_table_source:
175                    _qualify(source)
176
                       # The callback fires for every real table source, whether or not
                       # _qualify actually changed it.
177                    if on_qualify:
178                        on_qualify(source)
179            elif isinstance(source, Scope) and source.is_udtf:
180                _set_alias(udtf := source.expression, canonical_aliases)
181
182                table_alias = udtf.args["alias"]
183
                   # VALUES without column aliases gets dialect-generated, normalized ones.
184                if isinstance(udtf, exp.Values) and not table_alias.columns:
185                    column_aliases = [
186                        normalize_identifiers(i, dialect=dialect)
187                        for i in dialect.generate_values_aliases(udtf)
188                    ]
189                    table_alias.set("columns", column_aliases)
190
           # Alias any table directly under FROM/JOIN that still has no alias, using its
           # own name (not normalized here, since `normalize` is left at its default).
191        for table in scope.tables:
192            if not table.alias and isinstance(table.parent, (exp.From, exp.Join)):
193                _set_alias(table, canonical_aliases, target_alias=table.name)
194
195        for column in local_columns:
196            column_table = column.table
197
198            if column.db:
                   # Column qualified beyond a table name (e.g. db.tbl.col): look up every part
                   # but the last in table_aliases and, on a hit, reference the alias instead.
199                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))
200
201                if table_alias:
                       # presumably COLUMN_PARTS[1:] are the qualifier parts (table/db/catalog)
                       # — confirm against sqlglot.expressions
202                    for p in exp.COLUMN_PARTS[1:]:
203                        column.set(p, None)
204
205                    column.set("table", table_alias.copy())
206            elif (
207                canonical_aliases
208                and column_table
209                and (canonical_table := canonical_aliases.get(column_table, "")) != column_table
210            ):
                   # NOTE(review): the "" default makes an unmapped qualifier compare unequal,
                   # so the table would be set to an empty identifier — confirm that is intended
                   # or that every local column's table is always present in canonical_aliases.
211                # Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0
212                column.set("table", exp.to_identifier(canonical_table))
213
214    return expression
def qualify_tables( expression: ~E, db: str | sqlglot.expressions.core.Identifier | None = None, catalog: str | sqlglot.expressions.core.Identifier | None = None, on_qualify: Optional[Callable[[sqlglot.expressions.query.Table], NoneType]] = None, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, canonicalize_table_aliases: bool = False) -> ~E:
 17def qualify_tables(
 18    expression: E,
 19    db: str | exp.Identifier | None = None,
 20    catalog: str | exp.Identifier | None = None,
 21    on_qualify: t.Callable[[exp.Table], None] | None = None,
 22    dialect: DialectType = None,
 23    canonicalize_table_aliases: bool = False,
 24) -> E:
 25    """
 26    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
 27    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
 28
       Sources that have no alias are also given one, and columns written with a
       db/catalog-qualified table name are rewritten to reference that table's alias.

 29    Examples:
 30        >>> import sqlglot
 31        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
 32        >>> qualify_tables(expression, db="db").sql()
 33        'SELECT 1 FROM db.tbl AS tbl'
 34        >>>
 35        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
 36        >>> qualify_tables(expression).sql()
 37        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
 38
 39    Args:
 40        expression: Expr to qualify
 41        db: Database name
 42        catalog: Catalog name
 43        on_qualify: Callback after a table has been qualified.
 44        dialect: The dialect used to parse and normalize the db and catalog identifiers.
 45        canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources
 46            instead of preserving table names. Defaults to False.
 47
 48    Returns:
 49        The qualified expression (the object that was passed in, modified in place).
 50    """
       # Resolve the `dialect` argument into a dialect object; its attributes are used below.
 51    dialect = Dialect.get_or_raise(dialect)
       # Supplies generated alias names (the docstring describes them as _0, _1, ...).
 52    next_alias_name = name_sequence("_")
 53
       # Falsy db/catalog values (e.g. "") are treated as missing. Otherwise each is parsed
       # into an identifier, tagged via meta["is_table"], and normalized for the dialect.
 54    if db := db or None:
 55        db = exp.parse_identifier(db, dialect=dialect)
 56        db.meta["is_table"] = True
 57        db = normalize_identifiers(db, dialect=dialect)
 58    if catalog := catalog or None:
 59        catalog = exp.parse_identifier(catalog, dialect=dialect)
 60        catalog.meta["is_table"] = True
 61        catalog = normalize_identifiers(catalog, dialect=dialect)
 62
 63    def _qualify(table: exp.Table) -> None:
           # Fill in the default db/catalog on `table` in place. Only tables whose `this`
           # is a plain Identifier are touched, and existing db/catalog parts are kept.
 64        if isinstance(table.this, exp.Identifier):
 65            if db and not table.args.get("db"):
 66                table.set("db", db.copy())
               # The catalog is only added when the table has a db (possibly the one just set).
 67            if catalog and not table.args.get("catalog") and table.args.get("db"):
 68                table.set("catalog", catalog.copy())
 69
       # Root is not an exp.Query: qualify the tables reachable from it directly. The walk
       # is pruned at exp.Query nodes (presumably those are covered by the scope traversal
       # below), and names matching CTEs in the root's WITH clause are skipped.
 70    if (db or catalog) and not isinstance(expression, exp.Query):
 71        with_ = expression.args.get("with_") or exp.With()
 72        cte_names = {cte.alias_or_name for cte in with_.expressions}
 73
 74        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
 75            if isinstance(node, exp.Table) and node.name not in cte_names:
 76                _qualify(node)
 77
 78    def _set_alias(
 79        expression: exp.Expr,
 80        canonical_aliases: dict[str, str],
 81        target_alias: str | None = None,
 82        scope: Scope | None = None,
 83        normalize: bool = False,
 84        columns: Sequence[str | exp.Identifier] | None = None,
 85    ) -> None:
           # Give `expression` a table alias when one is needed.
           #   - canonicalize_table_aliases: always assign the next generated name and record
           #     the mapping from the previous name (alias name, else target_alias, else "")
           #     in `canonical_aliases`.
           #   - no alias name yet: use `target_alias`, falling back to a generated name; the
           #     name is normalized only when `normalize` is set and `target_alias` was given.
           #   - otherwise the existing alias is kept and nothing else in here runs.
           # `columns`, when given, become the alias' column names. `scope`, when given, is
           # notified through rename_source(None, <new alias>).
 86        alias = expression.args.get("alias") or exp.TableAlias()
 87
 88        if canonicalize_table_aliases:
 89            new_alias_name = next_alias_name()
 90            canonical_aliases[alias.name or target_alias or ""] = new_alias_name
 91        elif not alias.name:
 92            new_alias_name = target_alias or next_alias_name()
 93            if normalize and target_alias:
 94                new_alias_name = normalize_identifiers(new_alias_name, dialect=dialect).name
 95        else:
 96            return
 97
 98        alias.set("this", exp.to_identifier(new_alias_name))
 99
100        if columns:
101            alias.set("columns", [exp.to_identifier(c) for c in columns])
102
103        expression.set("alias", alias)
104
105        if scope:
               # presumably re-keys the scope's unnamed (None) source under the new alias —
               # confirm against Scope.rename_source
106            scope.rename_source(None, new_alias_name)
107
108    for scope in traverse_scope(expression):
           # Read up front, before any alias in this scope is rewritten below.
109        local_columns = scope.local_columns
           # Per-scope mapping of previous source name -> canonical alias (see _set_alias).
110        canonical_aliases: dict[str, str] = {}
111
112        for query in scope.subqueries:
113            subquery = query.parent
114            if isinstance(subquery, exp.Subquery):
                   # presumably collapses redundant wrapping around the subquery — confirm
                   # against Subquery.unwrap
115                subquery.unwrap().replace(subquery)
116
117        for derived_table in scope.derived_tables:
118            unnested = derived_table.unnest()
119            if isinstance(unnested, exp.Table):
                   # The derived table wraps a bare table (e.g. `(t1 JOIN t2) AS t`): swap its
                   # content for `SELECT * FROM <table>` and move the joins to that select.
120                joins = unnested.args.get("joins")
121                unnested.set("joins", None)
122                derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
123                derived_table.this.set("joins", joins)
124
125            _set_alias(derived_table, canonical_aliases, scope=scope)
               # Only the first pivot of a derived table is aliased.
126            if pivot := seq_get(derived_table.args.get("pivots") or [], 0):
127                _set_alias(pivot, canonical_aliases)
128
           # Maps a table's dotted name as written in the query (built before _qualify runs
           # on it) to its alias identifier; used to rewrite qualified columns further down.
129        table_aliases = {}
130
131        for name, source in scope.sources.items():
132            if isinstance(source, exp.Table):
133                # When the name is empty, it means that we have a non-table source, e.g. a pivoted cte
134                is_real_table_source = bool(name)
135
                   # For a pivoted table, alias it after the table's own name.
136                if pivot := seq_get(source.args.get("pivots") or [], 0):
137                    name = source.name
138
139                table_this = source.this
                   # Captured before _set_alias may attach or replace the alias.
140                table_alias = source.args.get("alias")
141                function_columns: Sequence[str | exp.Identifier] | None = None
142                if isinstance(table_this, exp.Func):
                       # Function used as a table source. Pick the alias column names:
                       #   1. no alias: the dialect's default column names for this function;
                       #   2. alias with columns: keep those columns;
                       #   3. alias without columns, function has dialect defaults: reuse the
                       #      alias name as the column name and drop it as the table alias.
143                    if not table_alias:
144                        function_columns = ensure_list(
145                            dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES.get(type(table_this))
146                        )
147                    elif columns := table_alias.columns:
148                        function_columns = columns
149                    elif type(table_this) in dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES:
150                        function_columns = ensure_list(source.alias_or_name)
151                        source.set("alias", None)
152                        name = ""
153
154                _set_alias(
155                    source,
156                    canonical_aliases,
157                    target_alias=name or source.name or None,
158                    normalize=True,
159                    columns=function_columns,
160                )
161
162                source_fqn = ".".join(p.name for p in source.parts)
163                had_explicit_alias = table_alias and table_alias.name
                   # A table without an explicit alias always takes the mapping for its name;
                   # an explicitly aliased one only fills it in when the name is still unmapped.
164                if not had_explicit_alias or source_fqn not in table_aliases:
165                    table_aliases[source_fqn] = source.args["alias"].this.copy()
166
167                if pivot:
                       # An UNPIVOT is aliased after the source's alias; a PIVOT gets a
                       # generated name (or keeps an alias it already has).
168                    target_alias = source.alias if pivot.unpivot else None
169                    _set_alias(pivot, canonical_aliases, target_alias=target_alias, normalize=True)
170
171                    # This case corresponds to a pivoted CTE, we don't want to qualify that
172                    if isinstance(scope.sources.get(source.alias_or_name), Scope):
173                        continue
174
175                if is_real_table_source:
176                    _qualify(source)
177
                       # The callback fires for every real table source, whether or not
                       # _qualify actually changed it.
178                    if on_qualify:
179                        on_qualify(source)
180            elif isinstance(source, Scope) and source.is_udtf:
181                _set_alias(udtf := source.expression, canonical_aliases)
182
183                table_alias = udtf.args["alias"]
184
                   # VALUES without column aliases gets dialect-generated, normalized ones.
185                if isinstance(udtf, exp.Values) and not table_alias.columns:
186                    column_aliases = [
187                        normalize_identifiers(i, dialect=dialect)
188                        for i in dialect.generate_values_aliases(udtf)
189                    ]
190                    table_alias.set("columns", column_aliases)
191
           # Alias any table directly under FROM/JOIN that still has no alias, using its
           # own name (not normalized here, since `normalize` is left at its default).
192        for table in scope.tables:
193            if not table.alias and isinstance(table.parent, (exp.From, exp.Join)):
194                _set_alias(table, canonical_aliases, target_alias=table.name)
195
196        for column in local_columns:
197            column_table = column.table
198
199            if column.db:
                   # Column qualified beyond a table name (e.g. db.tbl.col): look up every part
                   # but the last in table_aliases and, on a hit, reference the alias instead.
200                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))
201
202                if table_alias:
                       # presumably COLUMN_PARTS[1:] are the qualifier parts (table/db/catalog)
                       # — confirm against sqlglot.expressions
203                    for p in exp.COLUMN_PARTS[1:]:
204                        column.set(p, None)
205
206                    column.set("table", table_alias.copy())
207            elif (
208                canonical_aliases
209                and column_table
210                and (canonical_table := canonical_aliases.get(column_table, "")) != column_table
211            ):
                   # NOTE(review): the "" default makes an unmapped qualifier compare unequal,
                   # so the table would be set to an empty identifier — confirm that is intended
                   # or that every local column's table is always present in canonical_aliases.
212                # Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0
213                column.set("table", exp.to_identifier(canonical_table))
214
215    return expression

Rewrite sqlglot AST to have fully qualified tables. Join constructs such as (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
>>> qualify_tables(expression).sql()
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
Arguments:
  • expression: Expr to qualify
  • db: Database name
  • catalog: Catalog name
  • on_qualify: Callback after a table has been qualified.
  • dialect: The dialect used to parse and normalize the db and catalog identifiers.
  • canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources instead of preserving table names. Defaults to False.
Returns:

The qualified expression.