sqlglot.optimizer.qualify_tables
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.helper import name_sequence, seq_get, ensure_list
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.scope import Scope, traverse_scope

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_tables(
    expression: E,
    db: t.Optional[str | exp.Identifier] = None,
    catalog: t.Optional[str | exp.Identifier] = None,
    on_qualify: t.Optional[t.Callable[[exp.Table], None]] = None,
    dialect: DialectType = None,
    canonicalize_table_aliases: bool = False,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

    Besides filling in missing `db` / `catalog` parts, table sources, derived tables,
    pivots and UDTF sources that lack an alias are given one, and columns that refer
    to a table by its dotted qualified name are rewritten to use that table's alias.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
        >>> qualify_tables(expression).sql()
        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'

    Args:
        expression: Expression to qualify. The same object is returned.
        db: Database name to set on tables that do not already have one.
        catalog: Catalog name to set on tables that have a database but no catalog.
        on_qualify: Callback after a table has been qualified.
        dialect: The dialect used to parse and normalize `db` and `catalog`, and to
            normalize generated aliases.
        canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources
            instead of preserving table names. Defaults to False.

    Returns:
        The qualified expression.
    """
    # Resolve the `dialect` argument once; the result is used below for identifier
    # normalization and for dialect-specific defaults.
    dialect = Dialect.get_or_raise(dialect)
    # Source of generated alias names using the "_" prefix (the docstring cites _0, _1, ...).
    next_alias_name = name_sequence("_")

    # `db or None` / `catalog or None` treat falsy inputs (e.g. "") as "not provided".
    if db := db or None:
        db = exp.parse_identifier(db, dialect=dialect)
        # NOTE(review): presumably tells normalize_identifiers to apply table-name
        # rules to this identifier -- confirm against normalize_identifiers.
        db.meta["is_table"] = True
        db = normalize_identifiers(db, dialect=dialect)
    if catalog := catalog or None:
        catalog = exp.parse_identifier(catalog, dialect=dialect)
        catalog.meta["is_table"] = True
        catalog = normalize_identifiers(catalog, dialect=dialect)

    def _qualify(table: exp.Table) -> None:
        """Fill in the missing db / catalog of `table` when its name is a plain identifier."""
        if isinstance(table.this, exp.Identifier):
            if db and not table.args.get("db"):
                # Each table receives its own copy of the identifier node.
                table.set("db", db.copy())
            # The catalog is only added when a db is present (possibly set just above).
            if catalog and not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog.copy())

    # For statements that are not queries themselves, qualify the tables reachable
    # without crossing into a nested Query (the walk is pruned at Query nodes),
    # skipping names that match a CTE declared on the statement.
    if (db or catalog) and not isinstance(expression, exp.Query):
        with_ = expression.args.get("with_") or exp.With()
        cte_names = {cte.alias_or_name for cte in with_.expressions}

        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
            if isinstance(node, exp.Table) and node.name not in cte_names:
                _qualify(node)

    def _set_alias(
        expression: exp.Expression,
        canonical_aliases: t.Dict[str, str],
        target_alias: t.Optional[str] = None,
        scope: t.Optional[Scope] = None,
        normalize: bool = False,
        columns: t.Optional[t.List[t.Union[str, exp.Identifier]]] = None,
    ) -> None:
        """
        Give `expression` a table alias according to the aliasing mode.

        With `canonicalize_table_aliases` enabled, a freshly generated name always
        replaces the alias, and the mapping from the previous name (existing alias,
        else `target_alias`, else "") to the new one is recorded in `canonical_aliases`.
        Otherwise an alias is only added when none exists: `target_alias` if given
        (normalized when `normalize` is set), else a generated name. An expression
        that already has an alias is left untouched in that mode.

        Args:
            expression: The node whose "alias" arg is set.
            canonical_aliases: Per-scope mapping of previous name -> canonical alias,
                updated in place in canonical mode.
            target_alias: Preferred alias name when one has to be created.
            scope: If given, `scope.rename_source(None, <new alias>)` is called.
            normalize: Whether to normalize `target_alias` for the dialect.
            columns: Optional column names to attach to the alias.
        """
        alias = expression.args.get("alias") or exp.TableAlias()

        if canonicalize_table_aliases:
            new_alias_name = next_alias_name()
            canonical_aliases[alias.name or target_alias or ""] = new_alias_name
        elif not alias.name:
            new_alias_name = target_alias or next_alias_name()
            # Only a caller-provided name is normalized; generated names are used as-is.
            if normalize and target_alias:
                new_alias_name = normalize_identifiers(new_alias_name, dialect=dialect).name
        else:
            # Existing alias and no canonicalization requested: nothing to do.
            return

        alias.set("this", exp.to_identifier(new_alias_name))

        if columns:
            alias.set("columns", [exp.to_identifier(c) for c in columns])

        expression.set("alias", alias)

        if scope:
            # NOTE(review): the source is looked up under the None key, presumably where
            # the scope tracks a derived table that had no alias -- confirm in Scope.
            scope.rename_source(None, new_alias_name)

    for scope in traverse_scope(expression):
        # Captured before the scope's sources are mutated below.
        local_columns = scope.local_columns
        # Previous name -> canonical alias, rebuilt for every scope.
        canonical_aliases: t.Dict[str, str] = {}

        for query in scope.subqueries:
            subquery = query.parent
            if isinstance(subquery, exp.Subquery):
                # NOTE(review): presumably collapses redundant nested Subquery wrappers
                # into a single one -- confirm against Subquery.unwrap.
                subquery.unwrap().replace(subquery)

        for derived_table in scope.derived_tables:
            unnested = derived_table.unnest()
            if isinstance(unnested, exp.Table):
                # A parenthesized table/join such as `(t1 JOIN t2) AS t`: wrap it in
                # `SELECT * FROM ...` and move its joins onto the new select.
                joins = unnested.args.get("joins")
                unnested.set("joins", None)
                derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
                derived_table.this.set("joins", joins)

            _set_alias(derived_table, canonical_aliases, scope=scope)
            if pivot := seq_get(derived_table.args.get("pivots") or [], 0):
                _set_alias(pivot, canonical_aliases)

        # Dotted fully-qualified table name -> copy of that table's alias identifier;
        # consulted in the column loop at the end of this scope.
        table_aliases = {}

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                # When the name is empty, it means that we have a non-table source, e.g. a pivoted cte
                is_real_table_source = bool(name)

                if pivot := seq_get(source.args.get("pivots") or [], 0):
                    name = source.name

                table_this = source.this
                table_alias = source.args.get("alias")
                function_columns: t.List[t.Union[str, exp.Identifier]] = []
                if isinstance(table_this, exp.Func):
                    if not table_alias:
                        # No alias at all: use the dialect's default column names
                        # registered for this function type, if any.
                        function_columns = ensure_list(
                            dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES.get(type(table_this))
                        )
                    elif columns := table_alias.columns:
                        # Explicit column aliases are kept.
                        function_columns = columns
                    elif type(table_this) in dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES:
                        # The existing alias is reused as the column name; the alias is
                        # cleared and `name` reset so a new table alias is chosen below.
                        function_columns = ensure_list(source.alias_or_name)
                        source.set("alias", None)
                        name = None

                _set_alias(
                    source,
                    canonical_aliases,
                    target_alias=name or source.name or None,
                    normalize=True,
                    columns=function_columns,
                )

                source_fqn = ".".join(p.name for p in source.parts)
                table_aliases[source_fqn] = source.args["alias"].this.copy()

                if pivot:
                    # An UNPIVOT takes the source's alias as its preferred alias name;
                    # a PIVOT gets no preferred name.
                    target_alias = source.alias if pivot.unpivot else None
                    _set_alias(pivot, canonical_aliases, target_alias=target_alias, normalize=True)

                    # This case corresponds to a pivoted CTE, we don't want to qualify that
                    if isinstance(scope.sources.get(source.alias_or_name), Scope):
                        continue

                if is_real_table_source:
                    _qualify(source)

                    if on_qualify:
                        on_qualify(source)
            elif isinstance(source, Scope) and source.is_udtf:
                _set_alias(udtf := source.expression, canonical_aliases)

                table_alias = udtf.args["alias"]

                # VALUES without column aliases gets the dialect-generated ones, normalized.
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    column_aliases = [
                        normalize_identifiers(i, dialect=dialect)
                        for i in dialect.generate_values_aliases(udtf)
                    ]
                    table_alias.set("columns", column_aliases)

        # Alias any table directly under FROM / JOIN that is still unaliased,
        # using its own name as the preferred alias.
        for table in scope.tables:
            if not table.alias and isinstance(table.parent, (exp.From, exp.Join)):
                _set_alias(table, canonical_aliases, target_alias=table.name)

        for column in local_columns:
            table = column.table

            if column.db:
                # Look the table up by every dotted part except the last one.
                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))

                if table_alias:
                    # Clear every arg named in exp.COLUMN_PARTS[1:], then point the
                    # column at the table's alias instead.
                    for p in exp.COLUMN_PARTS[1:]:
                        column.set(p, None)

                    column.set("table", table_alias.copy())
            elif (
                canonical_aliases
                and table
                and (canonical_table := canonical_aliases.get(table, "")) != column.table
            ):
                # Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0
                # NOTE(review): if `table` has no entry in canonical_aliases, the ""
                # default also differs from column.table, so "" is passed to
                # exp.to_identifier here -- confirm that this is intended.
                column.set("table", exp.to_identifier(canonical_table))

    return expression
def
qualify_tables( expression: ~E, db: Union[sqlglot.expressions.Identifier, str, NoneType] = None, catalog: Union[sqlglot.expressions.Identifier, str, NoneType] = None, on_qualify: Optional[Callable[[sqlglot.expressions.Table], NoneType]] = None, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, canonicalize_table_aliases: bool = False) -> ~E:
def qualify_tables(
    expression: E,
    db: t.Optional[str | exp.Identifier] = None,
    catalog: t.Optional[str | exp.Identifier] = None,
    on_qualify: t.Optional[t.Callable[[exp.Table], None]] = None,
    dialect: DialectType = None,
    canonicalize_table_aliases: bool = False,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

    Besides filling in missing `db` / `catalog` parts, table sources, derived tables,
    pivots and UDTF sources that lack an alias are given one, and columns that refer
    to a table by its dotted qualified name are rewritten to use that table's alias.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
        >>> qualify_tables(expression).sql()
        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'

    Args:
        expression: Expression to qualify. The same object is returned.
        db: Database name to set on tables that do not already have one.
        catalog: Catalog name to set on tables that have a database but no catalog.
        on_qualify: Callback after a table has been qualified.
        dialect: The dialect used to parse and normalize `db` and `catalog`, and to
            normalize generated aliases.
        canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources
            instead of preserving table names. Defaults to False.

    Returns:
        The qualified expression.
    """
    # Resolve the `dialect` argument once; the result is used below for identifier
    # normalization and for dialect-specific defaults.
    dialect = Dialect.get_or_raise(dialect)
    # Source of generated alias names using the "_" prefix (the docstring cites _0, _1, ...).
    next_alias_name = name_sequence("_")

    # `db or None` / `catalog or None` treat falsy inputs (e.g. "") as "not provided".
    if db := db or None:
        db = exp.parse_identifier(db, dialect=dialect)
        # NOTE(review): presumably tells normalize_identifiers to apply table-name
        # rules to this identifier -- confirm against normalize_identifiers.
        db.meta["is_table"] = True
        db = normalize_identifiers(db, dialect=dialect)
    if catalog := catalog or None:
        catalog = exp.parse_identifier(catalog, dialect=dialect)
        catalog.meta["is_table"] = True
        catalog = normalize_identifiers(catalog, dialect=dialect)

    def _qualify(table: exp.Table) -> None:
        """Fill in the missing db / catalog of `table` when its name is a plain identifier."""
        if isinstance(table.this, exp.Identifier):
            if db and not table.args.get("db"):
                # Each table receives its own copy of the identifier node.
                table.set("db", db.copy())
            # The catalog is only added when a db is present (possibly set just above).
            if catalog and not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog.copy())

    # For statements that are not queries themselves, qualify the tables reachable
    # without crossing into a nested Query (the walk is pruned at Query nodes),
    # skipping names that match a CTE declared on the statement.
    if (db or catalog) and not isinstance(expression, exp.Query):
        with_ = expression.args.get("with_") or exp.With()
        cte_names = {cte.alias_or_name for cte in with_.expressions}

        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
            if isinstance(node, exp.Table) and node.name not in cte_names:
                _qualify(node)

    def _set_alias(
        expression: exp.Expression,
        canonical_aliases: t.Dict[str, str],
        target_alias: t.Optional[str] = None,
        scope: t.Optional[Scope] = None,
        normalize: bool = False,
        columns: t.Optional[t.List[t.Union[str, exp.Identifier]]] = None,
    ) -> None:
        """
        Give `expression` a table alias according to the aliasing mode.

        With `canonicalize_table_aliases` enabled, a freshly generated name always
        replaces the alias, and the mapping from the previous name (existing alias,
        else `target_alias`, else "") to the new one is recorded in `canonical_aliases`.
        Otherwise an alias is only added when none exists: `target_alias` if given
        (normalized when `normalize` is set), else a generated name. An expression
        that already has an alias is left untouched in that mode.

        Args:
            expression: The node whose "alias" arg is set.
            canonical_aliases: Per-scope mapping of previous name -> canonical alias,
                updated in place in canonical mode.
            target_alias: Preferred alias name when one has to be created.
            scope: If given, `scope.rename_source(None, <new alias>)` is called.
            normalize: Whether to normalize `target_alias` for the dialect.
            columns: Optional column names to attach to the alias.
        """
        alias = expression.args.get("alias") or exp.TableAlias()

        if canonicalize_table_aliases:
            new_alias_name = next_alias_name()
            canonical_aliases[alias.name or target_alias or ""] = new_alias_name
        elif not alias.name:
            new_alias_name = target_alias or next_alias_name()
            # Only a caller-provided name is normalized; generated names are used as-is.
            if normalize and target_alias:
                new_alias_name = normalize_identifiers(new_alias_name, dialect=dialect).name
        else:
            # Existing alias and no canonicalization requested: nothing to do.
            return

        alias.set("this", exp.to_identifier(new_alias_name))

        if columns:
            alias.set("columns", [exp.to_identifier(c) for c in columns])

        expression.set("alias", alias)

        if scope:
            # NOTE(review): the source is looked up under the None key, presumably where
            # the scope tracks a derived table that had no alias -- confirm in Scope.
            scope.rename_source(None, new_alias_name)

    for scope in traverse_scope(expression):
        # Captured before the scope's sources are mutated below.
        local_columns = scope.local_columns
        # Previous name -> canonical alias, rebuilt for every scope.
        canonical_aliases: t.Dict[str, str] = {}

        for query in scope.subqueries:
            subquery = query.parent
            if isinstance(subquery, exp.Subquery):
                # NOTE(review): presumably collapses redundant nested Subquery wrappers
                # into a single one -- confirm against Subquery.unwrap.
                subquery.unwrap().replace(subquery)

        for derived_table in scope.derived_tables:
            unnested = derived_table.unnest()
            if isinstance(unnested, exp.Table):
                # A parenthesized table/join such as `(t1 JOIN t2) AS t`: wrap it in
                # `SELECT * FROM ...` and move its joins onto the new select.
                joins = unnested.args.get("joins")
                unnested.set("joins", None)
                derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
                derived_table.this.set("joins", joins)

            _set_alias(derived_table, canonical_aliases, scope=scope)
            if pivot := seq_get(derived_table.args.get("pivots") or [], 0):
                _set_alias(pivot, canonical_aliases)

        # Dotted fully-qualified table name -> copy of that table's alias identifier;
        # consulted in the column loop at the end of this scope.
        table_aliases = {}

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                # When the name is empty, it means that we have a non-table source, e.g. a pivoted cte
                is_real_table_source = bool(name)

                if pivot := seq_get(source.args.get("pivots") or [], 0):
                    name = source.name

                table_this = source.this
                table_alias = source.args.get("alias")
                function_columns: t.List[t.Union[str, exp.Identifier]] = []
                if isinstance(table_this, exp.Func):
                    if not table_alias:
                        # No alias at all: use the dialect's default column names
                        # registered for this function type, if any.
                        function_columns = ensure_list(
                            dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES.get(type(table_this))
                        )
                    elif columns := table_alias.columns:
                        # Explicit column aliases are kept.
                        function_columns = columns
                    elif type(table_this) in dialect.DEFAULT_FUNCTIONS_COLUMN_NAMES:
                        # The existing alias is reused as the column name; the alias is
                        # cleared and `name` reset so a new table alias is chosen below.
                        function_columns = ensure_list(source.alias_or_name)
                        source.set("alias", None)
                        name = None

                _set_alias(
                    source,
                    canonical_aliases,
                    target_alias=name or source.name or None,
                    normalize=True,
                    columns=function_columns,
                )

                source_fqn = ".".join(p.name for p in source.parts)
                table_aliases[source_fqn] = source.args["alias"].this.copy()

                if pivot:
                    # An UNPIVOT takes the source's alias as its preferred alias name;
                    # a PIVOT gets no preferred name.
                    target_alias = source.alias if pivot.unpivot else None
                    _set_alias(pivot, canonical_aliases, target_alias=target_alias, normalize=True)

                    # This case corresponds to a pivoted CTE, we don't want to qualify that
                    if isinstance(scope.sources.get(source.alias_or_name), Scope):
                        continue

                if is_real_table_source:
                    _qualify(source)

                    if on_qualify:
                        on_qualify(source)
            elif isinstance(source, Scope) and source.is_udtf:
                _set_alias(udtf := source.expression, canonical_aliases)

                table_alias = udtf.args["alias"]

                # VALUES without column aliases gets the dialect-generated ones, normalized.
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    column_aliases = [
                        normalize_identifiers(i, dialect=dialect)
                        for i in dialect.generate_values_aliases(udtf)
                    ]
                    table_alias.set("columns", column_aliases)

        # Alias any table directly under FROM / JOIN that is still unaliased,
        # using its own name as the preferred alias.
        for table in scope.tables:
            if not table.alias and isinstance(table.parent, (exp.From, exp.Join)):
                _set_alias(table, canonical_aliases, target_alias=table.name)

        for column in local_columns:
            table = column.table

            if column.db:
                # Look the table up by every dotted part except the last one.
                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))

                if table_alias:
                    # Clear every arg named in exp.COLUMN_PARTS[1:], then point the
                    # column at the table's alias instead.
                    for p in exp.COLUMN_PARTS[1:]:
                        column.set(p, None)

                    column.set("table", table_alias.copy())
            elif (
                canonical_aliases
                and table
                and (canonical_table := canonical_aliases.get(table, "")) != column.table
            ):
                # Amend existing aliases, e.g. t.c -> _0.c if t is aliased to _0
                # NOTE(review): if `table` has no entry in canonical_aliases, the ""
                # default also differs from column.table, so "" is passed to
                # exp.to_identifier here -- confirm that this is intended.
                column.set("table", exp.to_identifier(canonical_table))

    return expression
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
>>> qualify_tables(expression).sql()
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
Arguments:
- expression: Expression to qualify
- db: Database name to set on tables that do not already have one.
- catalog: Catalog name to set on tables that have a database but no catalog.
- on_qualify: Callback after a table has been qualified.
- dialect: The dialect used to parse and normalize the `db` and `catalog` identifiers, and to normalize generated aliases.
- canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources instead of preserving table names. Defaults to False.
Returns:
The qualified expression.