sqlglot.optimizer.qualify_tables
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.helper import name_sequence, seq_get, ensure_list
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.scope import Scope, traverse_scope

if t.TYPE_CHECKING:
    from sqlglot._typing import E
    from collections.abc import Sequence


def qualify_tables(
    expression: E,
    db: str | exp.Identifier | None = None,
    catalog: str | exp.Identifier | None = None,
    on_qualify: t.Callable[[exp.Table], None] | None = None,
    dialect: DialectType = None,
    canonicalize_table_aliases: bool = False,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

    The tree is modified in place and the same `expression` object is returned.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
        >>> qualify_tables(expression).sql()
        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'

    Args:
        expression: Expr to qualify
        db: Database name. Applied to tables that do not already have one; falsy values
            (e.g. an empty string) are treated as "not provided".
        catalog: Catalog name. Applied only to tables that have a database (possibly the
            default one set from `db`) and no catalog; falsy values are treated as "not provided".
        on_qualify: Callback after a table has been qualified.
        dialect: The dialect to parse catalog and schema into.
        canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources
            instead of preserving table names. Defaults to False.

    Returns:
        The qualified expression.
    """
    dialect = Dialect.get_or_raise(dialect)
    # One alias-name generator for the whole call, so generated names ("_"-prefixed, see the
    # docstring's _0, _1, ...) are handed out in the order `_set_alias` is invoked.
    next_alias_name = name_sequence("_")

    # Collapse falsy inputs to None; otherwise parse the name into an identifier, flag it
    # as a table-level name via `meta` and normalize it for the dialect.
    if db := db or None:
        db = exp.parse_identifier(db, dialect=dialect)
        db.meta["is_table"] = True
        db = normalize_identifiers(db, dialect=dialect)
    if catalog := catalog or None:
        catalog = exp.parse_identifier(catalog, dialect=dialect)
        catalog.meta["is_table"] = True
        catalog = normalize_identifiers(catalog, dialect=dialect)

    def _qualify(table: exp.Table) -> None:
        """Fill in the default db / catalog on `table` when its `this` is a plain identifier."""
        if isinstance(table.this, exp.Identifier):
            if db and not table.args.get("db"):
                table.set("db", db.copy())
            # The catalog is only added once the table has a db (possibly the one set just above).
            if catalog and not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog.copy())

    # Non-query root expressions: qualify the tables reachable by walking the tree with
    # pruning at Query nodes, skipping names that match one of this expression's CTEs.
    # NOTE(review): tables inside nested queries are presumably covered by the scope
    # traversal below - confirm against `traverse_scope`.
    if (db or catalog) and not isinstance(expression, exp.Query):
        with_ = expression.args.get("with_") or exp.With()
        cte_names = {cte.alias_or_name for cte in with_.expressions}

        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
            if isinstance(node, exp.Table) and node.name not in cte_names:
                _qualify(node)

    def _set_alias(
        expression: exp.Expr,
        canonical_aliases: dict[str, str],
        target_alias: str | None = None,
        scope: Scope | None = None,
        normalize: bool = False,
        columns: Sequence[str | exp.Identifier] | None = None,
    ) -> None:
        """
        Ensure `expression` carries a table alias.

        With `canonicalize_table_aliases` enabled, a generated alias always replaces the
        current one and the mapping "previous alias name (or `target_alias`, or '') ->
        generated name" is recorded in `canonical_aliases`. Otherwise an alias is only
        assigned when none exists: `target_alias` if given (normalized for the dialect when
        `normalize` is set), else a generated name. An existing alias is left untouched.

        Args:
            expression: The node whose "alias" arg is set.
            canonical_aliases: Per-scope mapping that receives old -> canonical alias names.
            target_alias: Preferred alias name when the node has none.
            scope: If given, its source registered under `None` is renamed to the new alias.
            normalize: Whether to normalize `target_alias` before using it.
            columns: Optional column aliases to attach to the table alias.
        """
        alias = expression.args.get("alias") or exp.TableAlias()

        if canonicalize_table_aliases:
            new_alias_name = next_alias_name()
            canonical_aliases[alias.name or target_alias or ""] = new_alias_name
        elif not alias.name:
            new_alias_name = target_alias or next_alias_name()
            if normalize and target_alias:
                new_alias_name = normalize_identifiers(new_alias_name, dialect=dialect).name
        else:
            # Already aliased and canonicalization is off: nothing to do.
            return

        alias.set("this", exp.to_identifier(new_alias_name))

        if columns:
            alias.set("columns", [exp.to_identifier(c) for c in columns])

        expression.set("alias", alias)

        if scope:
            # NOTE(review): presumably an unaliased source is keyed by None in the scope - confirm.
            scope.rename_source(None, new_alias_name)

    for scope in traverse_scope(expression):
        # Read before any alias is assigned below.
        # NOTE(review): presumably so the column set reflects the scope as written - confirm.
        local_columns = scope.local_columns
        # Old alias name -> canonical alias name, only populated when canonicalizing.
        canonical_aliases: dict[str, str] = {}

        # Replace the result of `unwrap()` with the Subquery that directly wraps each query.
        # NOTE(review): presumably this collapses redundant nested parentheses - confirm.
        for query in scope.subqueries:
            subquery = query.parent
            if isinstance(subquery, exp.Subquery):
                subquery.unwrap().replace(subquery)

        for derived_table in scope.derived_tables:
            unnested = derived_table.unnest()
            if isinstance(unnested, exp.Table):
                # A parenthesized table / join: wrap it in SELECT * FROM <table> and move the
                # joins from the table onto the new SELECT (see the docstring example).
                joins = unnested.args.get("joins")
                unnested.set("joins", None)
                derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
                derived_table.this.set("joins", joins)

            _set_alias(derived_table, canonical_aliases, scope=scope)
            # Only the first pivot, if any, is aliased.
            if pivot := seq_get(derived_table.args.get("pivots") or [], 0):
                _set_alias(pivot, canonical_aliases)

        # Dotted name of a table as written in the query -> alias identifier assigned to it.
        table_aliases = {}

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                # When the name is empty, it means that we have a non-table source, e.g. a pivoted cte
                is_real_table_source = bool(name)

                if pivot := seq_get(source.args.get("pivots") or [], 0):
                    name = source.name

                table_this = source.this
                table_alias = source.args.get("alias")
                function_columns: Sequence[str | exp.Identifier] | None = None
                if isinstance(table_this, exp.Func):
                    if not table_alias:
                        # Unaliased table function: use the dialect's default column names, if any.
                        function_columns = ensure_list(
                            dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES.get(type(table_this))
                        )
                    elif columns := table_alias.columns:
                        # Explicit column aliases are kept.
                        function_columns = columns
                    elif type(table_this) in dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES:
                        # Alias without columns on a function that has default column names:
                        # the alias name is used as the column name, the alias is dropped and
                        # `name` is cleared so it is not reused as the table alias.
                        function_columns = ensure_list(source.alias_or_name)
                        source.set("alias", None)
                        name = ""

                _set_alias(
                    source,
                    canonical_aliases,
                    target_alias=name or source.name or None,
                    normalize=True,
                    columns=function_columns,
                )

                # Computed before `_qualify`, i.e. from the parts present in the original query.
                source_fqn = ".".join(p.name for p in source.parts)
                had_explicit_alias = table_alias and table_alias.name
                # An originally unaliased table always claims its entry; an explicitly aliased
                # one only fills an entry that is still missing.
                if not had_explicit_alias or source_fqn not in table_aliases:
                    table_aliases[source_fqn] = source.args["alias"].this.copy()

                if pivot:
                    # For UNPIVOT the pivot reuses the table's alias as its preferred alias.
                    target_alias = source.alias if pivot.unpivot else None
                    _set_alias(pivot, canonical_aliases, target_alias=target_alias, normalize=True)

                    # This case corresponds to a pivoted CTE, we don't want to qualify that
                    if isinstance(scope.sources.get(source.alias_or_name), Scope):
                        continue

                if is_real_table_source:
                    _qualify(source)

                if on_qualify:
                    on_qualify(source)
            elif isinstance(source, Scope) and source.is_udtf:
                _set_alias(udtf := source.expression, canonical_aliases)

                table_alias = udtf.args["alias"]

                # VALUES without column aliases gets the dialect-generated ones, normalized.
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    column_aliases = [
                        normalize_identifiers(i, dialect=dialect)
                        for i in dialect.generate_values_aliases(udtf)
                    ]
                    table_alias.set("columns", column_aliases)

        # Any table in FROM / JOIN position that is still unaliased is aliased with its own name.
        for table in scope.tables:
            if not table.alias and isinstance(table.parent, (exp.From, exp.Join)):
                _set_alias(table, canonical_aliases, target_alias=table.name)

        for column in local_columns:
            column_table = column.table

            if column.db:
                # Column qualified with db (and possibly catalog): look up every part except the
                # column name itself among the tables seen in this scope.
                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))

                if table_alias:
                    # Clear the qualifier args, then point the column at the table's alias.
                    # NOTE(review): assumes COLUMN_PARTS[0] is the column name arg - confirm.
                    for p in exp.COLUMN_PARTS[1:]:
                        column.set(p, None)

                    column.set("table", table_alias.copy())
            elif (
                canonical_aliases
                and column_table
                and (canonical_table := canonical_aliases.get(column_table, "")) != column_table
            ):
                # Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0
                column.set("table", exp.to_identifier(canonical_table))

    return expression
def
qualify_tables( expression: ~E, db: str | sqlglot.expressions.core.Identifier | None = None, catalog: str | sqlglot.expressions.core.Identifier | None = None, on_qualify: Optional[Callable[[sqlglot.expressions.query.Table], NoneType]] = None, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, canonicalize_table_aliases: bool = False) -> ~E:
def qualify_tables(
    expression: E,
    db: str | exp.Identifier | None = None,
    catalog: str | exp.Identifier | None = None,
    on_qualify: t.Callable[[exp.Table], None] | None = None,
    dialect: DialectType = None,
    canonicalize_table_aliases: bool = False,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

    The tree is modified in place and the same `expression` object is returned.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
        >>> qualify_tables(expression).sql()
        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'

    Args:
        expression: Expr to qualify
        db: Database name. Applied to tables that do not already have one; falsy values
            (e.g. an empty string) are treated as "not provided".
        catalog: Catalog name. Applied only to tables that have a database (possibly the
            default one set from `db`) and no catalog; falsy values are treated as "not provided".
        on_qualify: Callback after a table has been qualified.
        dialect: The dialect to parse catalog and schema into.
        canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources
            instead of preserving table names. Defaults to False.

    Returns:
        The qualified expression.
    """
    dialect = Dialect.get_or_raise(dialect)
    # A single generator for the entire call: generated names ("_"-prefixed, the docstring's
    # _0, _1, ...) follow the order in which `_set_alias` requests them.
    next_alias_name = name_sequence("_")

    # Falsy inputs become None; anything else is parsed to an identifier, marked as a
    # table-level name through `meta` and normalized for the dialect.
    if db := db or None:
        db = exp.parse_identifier(db, dialect=dialect)
        db.meta["is_table"] = True
        db = normalize_identifiers(db, dialect=dialect)
    if catalog := catalog or None:
        catalog = exp.parse_identifier(catalog, dialect=dialect)
        catalog.meta["is_table"] = True
        catalog = normalize_identifiers(catalog, dialect=dialect)

    def _qualify(table: exp.Table) -> None:
        """Apply the default db / catalog to `table` if its `this` is a plain identifier."""
        if isinstance(table.this, exp.Identifier):
            if db and not table.args.get("db"):
                table.set("db", db.copy())
            # A catalog is only attached to a table that has a db by this point.
            if catalog and not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog.copy())

    # When the root is not a query, qualify tables found by walking the tree (pruned at
    # Query nodes), except those whose name matches a CTE defined on this expression.
    # NOTE(review): nested queries are presumably handled by the scope traversal further
    # down - confirm against `traverse_scope`.
    if (db or catalog) and not isinstance(expression, exp.Query):
        with_ = expression.args.get("with_") or exp.With()
        cte_names = {cte.alias_or_name for cte in with_.expressions}

        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
            if isinstance(node, exp.Table) and node.name not in cte_names:
                _qualify(node)

    def _set_alias(
        expression: exp.Expr,
        canonical_aliases: dict[str, str],
        target_alias: str | None = None,
        scope: Scope | None = None,
        normalize: bool = False,
        columns: Sequence[str | exp.Identifier] | None = None,
    ) -> None:
        """
        Make sure `expression` has a table alias.

        If `canonicalize_table_aliases` is on, the alias is always replaced by a generated
        name, and `canonical_aliases` records "previous alias name (or `target_alias`, or
        '') -> generated name". If it is off, only unaliased nodes are touched: they get
        `target_alias` when provided (normalized for the dialect if `normalize` is set) or
        a generated name otherwise. Nodes that already have an alias are returned as is.

        Args:
            expression: The node whose "alias" arg is set.
            canonical_aliases: Per-scope mapping that receives old -> canonical alias names.
            target_alias: Preferred alias name when the node has none.
            scope: If given, its source registered under `None` is renamed to the new alias.
            normalize: Whether to normalize `target_alias` before using it.
            columns: Optional column aliases to attach to the table alias.
        """
        alias = expression.args.get("alias") or exp.TableAlias()

        if canonicalize_table_aliases:
            new_alias_name = next_alias_name()
            canonical_aliases[alias.name or target_alias or ""] = new_alias_name
        elif not alias.name:
            new_alias_name = target_alias or next_alias_name()
            if normalize and target_alias:
                new_alias_name = normalize_identifiers(new_alias_name, dialect=dialect).name
        else:
            # Existing alias and no canonicalization requested: leave the node alone.
            return

        alias.set("this", exp.to_identifier(new_alias_name))

        if columns:
            alias.set("columns", [exp.to_identifier(c) for c in columns])

        expression.set("alias", alias)

        if scope:
            # NOTE(review): an unaliased source is presumably stored under None - confirm.
            scope.rename_source(None, new_alias_name)

    for scope in traverse_scope(expression):
        # Captured ahead of the alias changes made below.
        # NOTE(review): presumably to keep the column set as originally written - confirm.
        local_columns = scope.local_columns
        # Filled by `_set_alias` only when canonicalizing: old alias name -> canonical name.
        canonical_aliases: dict[str, str] = {}

        # For each subquery, swap the node returned by `unwrap()` for the Subquery that
        # directly wraps the query.
        # NOTE(review): presumably removes redundant nested parentheses - confirm.
        for query in scope.subqueries:
            subquery = query.parent
            if isinstance(subquery, exp.Subquery):
                subquery.unwrap().replace(subquery)

        for derived_table in scope.derived_tables:
            unnested = derived_table.unnest()
            if isinstance(unnested, exp.Table):
                # Parenthesized table / join: rewrite to SELECT * FROM <table>, carrying the
                # table's joins over to the new SELECT (the docstring's second example).
                joins = unnested.args.get("joins")
                unnested.set("joins", None)
                derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
                derived_table.this.set("joins", joins)

            _set_alias(derived_table, canonical_aliases, scope=scope)
            # Alias the first pivot only, when there is one.
            if pivot := seq_get(derived_table.args.get("pivots") or [], 0):
                _set_alias(pivot, canonical_aliases)

        # Maps a table's dotted name, as written in the query, to its alias identifier.
        table_aliases = {}

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                # When the name is empty, it means that we have a non-table source, e.g. a pivoted cte
                is_real_table_source = bool(name)

                if pivot := seq_get(source.args.get("pivots") or [], 0):
                    name = source.name

                table_this = source.this
                table_alias = source.args.get("alias")
                function_columns: Sequence[str | exp.Identifier] | None = None
                if isinstance(table_this, exp.Func):
                    if not table_alias:
                        # No alias at all: fall back to the dialect's default column names.
                        function_columns = ensure_list(
                            dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES.get(type(table_this))
                        )
                    elif columns := table_alias.columns:
                        # Column aliases given explicitly are preserved.
                        function_columns = columns
                    elif type(table_this) in dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES:
                        # Alias with no columns on a function with default column names: use
                        # the alias name as the column name, remove the alias and blank `name`
                        # so that it is not picked again as the table alias.
                        function_columns = ensure_list(source.alias_or_name)
                        source.set("alias", None)
                        name = ""

                _set_alias(
                    source,
                    canonical_aliases,
                    target_alias=name or source.name or None,
                    normalize=True,
                    columns=function_columns,
                )

                # Built before `_qualify` runs, so it reflects the parts written in the query.
                source_fqn = ".".join(p.name for p in source.parts)
                had_explicit_alias = table_alias and table_alias.name
                # Originally unaliased tables always take the entry; explicitly aliased ones
                # only fill it when it is still absent.
                if not had_explicit_alias or source_fqn not in table_aliases:
                    table_aliases[source_fqn] = source.args["alias"].this.copy()

                if pivot:
                    # UNPIVOT: the table's alias is the preferred alias for the pivot.
                    target_alias = source.alias if pivot.unpivot else None
                    _set_alias(pivot, canonical_aliases, target_alias=target_alias, normalize=True)

                    # This case corresponds to a pivoted CTE, we don't want to qualify that
                    if isinstance(scope.sources.get(source.alias_or_name), Scope):
                        continue

                if is_real_table_source:
                    _qualify(source)

                if on_qualify:
                    on_qualify(source)
            elif isinstance(source, Scope) and source.is_udtf:
                _set_alias(udtf := source.expression, canonical_aliases)

                table_alias = udtf.args["alias"]

                # VALUES lacking column aliases receives normalized, dialect-generated ones.
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    column_aliases = [
                        normalize_identifiers(i, dialect=dialect)
                        for i in dialect.generate_values_aliases(udtf)
                    ]
                    table_alias.set("columns", column_aliases)

        # Remaining unaliased tables directly under FROM / JOIN get their own name as alias.
        for table in scope.tables:
            if not table.alias and isinstance(table.parent, (exp.From, exp.Join)):
                _set_alias(table, canonical_aliases, target_alias=table.name)

        for column in local_columns:
            column_table = column.table

            if column.db:
                # db-qualified column: match all parts but the column name against the tables
                # recorded for this scope.
                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))

                if table_alias:
                    # Drop the qualifier args and reference the table through its alias.
                    # NOTE(review): assumes COLUMN_PARTS[0] is the column name arg - confirm.
                    for p in exp.COLUMN_PARTS[1:]:
                        column.set(p, None)

                    column.set("table", table_alias.copy())
            elif (
                canonical_aliases
                and column_table
                and (canonical_table := canonical_aliases.get(column_table, "")) != column_table
            ):
                # Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0
                column.set("table", exp.to_identifier(canonical_table))

    return expression
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
>>> qualify_tables(expression).sql()
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
Arguments:
- expression: Expr to qualify
- db: Database name
- catalog: Catalog name
- on_qualify: Callback after a table has been qualified.
- dialect: The dialect to parse catalog and schema into.
- canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources instead of preserving table names. Defaults to False.
Returns:
The qualified expression.