sqlglot.optimizer.qualify_tables
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
from sqlglot.dialects.dialect import Dialect

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_tables(
    expression: E,
    db: t.Optional[str | exp.Identifier] = None,
    catalog: t.Optional[str | exp.Identifier] = None,
    schema: t.Optional[Schema] = None,
    infer_csv_schemas: bool = False,
    dialect: DialectType = None,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

    The expression is modified in place and the same object is returned. Besides
    filling in missing `db` / `catalog` parts, this pass attaches an alias to
    every table, derived table, pivot and UDTF that lacks one (generated aliases
    come from the `_q_` name sequence) and rewrites db-qualified column
    references so that they point at the alias of the table they refer to.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
        >>> qualify_tables(expression).sql()
        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'

    Args:
        expression: Expression to qualify. It is mutated in place.
        db: Database name to set on tables that have no `db` part.
        catalog: Catalog name to set on tables that have a `db` part but no `catalog` part.
        schema: A schema to populate with inferred CSV table schemas. It is only
            touched when `infer_csv_schemas` is true.
        infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
        dialect: The dialect used to parse `db` and `catalog` into identifiers. It also
            drives the dialect-specific aliasing of UNNEST and VALUES sources.

    Returns:
        The qualified expression.
    """
    # Source of fresh alias names for anything that has to be aliased implicitly.
    next_alias_name = name_sequence("_q_")
    # Normalize the optional defaults into identifiers; falsy values (None, "") mean "not set".
    db = exp.parse_identifier(db, dialect=dialect) if db else None
    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
    # From here on `dialect` is a resolved Dialect (see UNNEST_COLUMN_ONLY / generate_values_aliases below).
    dialect = Dialect.get_or_raise(dialect)

    def _qualify(table: exp.Table) -> None:
        # Fill in the default db/catalog on `table` in place. Only tables whose
        # `this` is a plain identifier are touched.
        if isinstance(table.this, exp.Identifier):
            if db and not table.args.get("db"):
                table.set("db", db.copy())
            # The catalog is only added when a db part is present (possibly the one set just above).
            if catalog and not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog.copy())

    # Non-query root expressions: qualify the tables found by walking the root,
    # pruning the walk at Query nodes. Tables named after one of the root's CTEs are skipped.
    if (db or catalog) and not isinstance(expression, exp.Query):
        with_ = expression.args.get("with") or exp.With()
        cte_names = {cte.alias_or_name for cte in with_.expressions}

        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
            if isinstance(node, exp.Table) and node.name not in cte_names:
                _qualify(node)

    for scope in traverse_scope(expression):
        # Pass 1: normalize CTEs and derived tables of this scope.
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                if isinstance(unnested, exp.Table):
                    # A parenthesized table, e.g. (t1 JOIN t2), becomes SELECT * FROM t1 and
                    # the joins that were attached to the table are moved onto the new SELECT.
                    joins = unnested.args.pop("joins", None)
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
                    derived_table.this.set("joins", joins)

            if not derived_table.args.get("alias"):
                alias_ = next_alias_name()
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                # Keep the scope's source mapping in sync: the entry keyed by None is renamed.
                scope.rename_source(None, alias_)

            # Only the first pivot is given an alias here.
            pivots = derived_table.args.get("pivots")
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))

        # Maps the dotted, fully spelled name of each table source to its alias identifier;
        # consumed by the column rewrite at the end of this scope.
        table_aliases: t.Dict[str, exp.Identifier] = {}

        # Pass 2: alias and qualify every source of this scope.
        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                pivots = source.args.get("pivots")
                if not source.alias:
                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
                    if pivots and pivots[0].alias == name:
                        name = source.name

                    # Mutates the source by attaching an alias to it
                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)

                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
                    source.alias
                )

                if pivots:
                    pivot = pivots[0]
                    if not pivot.alias:
                        # An UNPIVOT reuses the table's alias, a PIVOT gets a generated one.
                        pivot_alias = source.alias if pivot.unpivot else next_alias_name()
                        pivot.set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias)))

                    # This case corresponds to a pivoted CTE, we don't want to qualify that
                    if isinstance(scope.sources.get(source.alias_or_name), Scope):
                        continue

                _qualify(source)

                if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
                    with csv_reader(source.this) as reader:
                        # The first row is used as the header and the second as a sample row.
                        # NOTE(review): next() raises StopIteration if the file has fewer than
                        # two rows - confirm callers tolerate that.
                        header = next(reader)
                        columns = next(reader)
                        # Each column type is recorded as the Python type name of the sampled value.
                        # NOTE(review): if csv_reader yields plain strings this is always "str" - confirm.
                        schema.add_table(
                            source,
                            {k: type(v).__name__ for k, v in zip(header, columns)},
                            match_depth=False,
                        )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                # Reuse the existing alias node if there is one, otherwise build a generated one.
                table_alias = udtf.args.get("alias") or exp.TableAlias(
                    this=exp.to_identifier(next_alias_name())
                )
                # Dialects with UNNEST_COLUMN_ONLY: an UNNEST alias without column aliases gets
                # its own identifier as the single column and is flagged as column-only.
                if (
                    isinstance(udtf, exp.Unnest)
                    and dialect.UNNEST_COLUMN_ONLY
                    and not table_alias.columns
                ):
                    table_alias.set("columns", [table_alias.this.copy()])
                    table_alias.set("column_only", True)

                udtf.set("alias", table_alias)

                # An existing alias node may carry columns but no name; give it a generated name.
                if not table_alias.name:
                    table_alias.set("this", exp.to_identifier(next_alias_name()))
                # VALUES without column aliases gets the dialect's generated column aliases.
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    column_aliases = dialect.generate_values_aliases(udtf)
                    table_alias.set("columns", column_aliases)
            else:
                # Any other kind of source: alias every still-unaliased table that sits directly
                # under a FROM or JOIN in this scope, using the table's own name.
                for node in scope.walk():
                    if (
                        isinstance(node, exp.Table)
                        and not node.alias
                        and isinstance(node.parent, (exp.From, exp.Join))
                    ):
                        # Mutates the table by attaching an alias to it
                        alias(node, node.name, copy=False, table=True)

        # Pass 3: columns that carry a db qualifier are looked up by their dotted qualifier
        # (all parts except the last) and, on a match, re-pointed at the table's alias.
        for column in scope.columns:
            if column.db:
                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))

                if table_alias:
                    # Clear the parts listed in exp.COLUMN_PARTS[1:] (presumably every qualifier
                    # part, leaving only the column name) before setting the alias as the table.
                    for p in exp.COLUMN_PARTS[1:]:
                        column.set(p, None)

                    column.set("table", table_alias.copy())

    return expression
def
qualify_tables( expression: ~E, db: Union[sqlglot.expressions.Identifier, str, NoneType] = None, catalog: Union[sqlglot.expressions.Identifier, str, NoneType] = None, schema: Optional[sqlglot.schema.Schema] = None, infer_csv_schemas: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> ~E:
def qualify_tables(
    expression: E,
    db: t.Optional[str | exp.Identifier] = None,
    catalog: t.Optional[str | exp.Identifier] = None,
    schema: t.Optional[Schema] = None,
    infer_csv_schemas: bool = False,
    dialect: DialectType = None,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

    The expression is modified in place and the same object is returned. Besides
    filling in missing `db` / `catalog` parts, this pass attaches an alias to
    every table, derived table, pivot and UDTF that lacks one (generated aliases
    come from the `_q_` name sequence) and rewrites db-qualified column
    references so that they point at the alias of the table they refer to.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
        >>> qualify_tables(expression).sql()
        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'

    Args:
        expression: Expression to qualify. It is mutated in place.
        db: Database name to set on tables that have no `db` part.
        catalog: Catalog name to set on tables that have a `db` part but no `catalog` part.
        schema: A schema to populate with inferred CSV table schemas. It is only
            touched when `infer_csv_schemas` is true.
        infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
        dialect: The dialect used to parse `db` and `catalog` into identifiers. It also
            drives the dialect-specific aliasing of UNNEST and VALUES sources.

    Returns:
        The qualified expression.
    """
    # Source of fresh alias names for anything that has to be aliased implicitly.
    next_alias_name = name_sequence("_q_")
    # Normalize the optional defaults into identifiers; falsy values (None, "") mean "not set".
    db = exp.parse_identifier(db, dialect=dialect) if db else None
    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
    # From here on `dialect` is a resolved Dialect (see UNNEST_COLUMN_ONLY / generate_values_aliases below).
    dialect = Dialect.get_or_raise(dialect)

    def _qualify(table: exp.Table) -> None:
        # Fill in the default db/catalog on `table` in place. Only tables whose
        # `this` is a plain identifier are touched.
        if isinstance(table.this, exp.Identifier):
            if db and not table.args.get("db"):
                table.set("db", db.copy())
            # The catalog is only added when a db part is present (possibly the one set just above).
            if catalog and not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog.copy())

    # Non-query root expressions: qualify the tables found by walking the root,
    # pruning the walk at Query nodes. Tables named after one of the root's CTEs are skipped.
    if (db or catalog) and not isinstance(expression, exp.Query):
        with_ = expression.args.get("with") or exp.With()
        cte_names = {cte.alias_or_name for cte in with_.expressions}

        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
            if isinstance(node, exp.Table) and node.name not in cte_names:
                _qualify(node)

    for scope in traverse_scope(expression):
        # Pass 1: normalize CTEs and derived tables of this scope.
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                if isinstance(unnested, exp.Table):
                    # A parenthesized table, e.g. (t1 JOIN t2), becomes SELECT * FROM t1 and
                    # the joins that were attached to the table are moved onto the new SELECT.
                    joins = unnested.args.pop("joins", None)
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
                    derived_table.this.set("joins", joins)

            if not derived_table.args.get("alias"):
                alias_ = next_alias_name()
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                # Keep the scope's source mapping in sync: the entry keyed by None is renamed.
                scope.rename_source(None, alias_)

            # Only the first pivot is given an alias here.
            pivots = derived_table.args.get("pivots")
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))

        # Maps the dotted, fully spelled name of each table source to its alias identifier;
        # consumed by the column rewrite at the end of this scope.
        table_aliases: t.Dict[str, exp.Identifier] = {}

        # Pass 2: alias and qualify every source of this scope.
        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                pivots = source.args.get("pivots")
                if not source.alias:
                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
                    if pivots and pivots[0].alias == name:
                        name = source.name

                    # Mutates the source by attaching an alias to it
                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)

                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
                    source.alias
                )

                if pivots:
                    pivot = pivots[0]
                    if not pivot.alias:
                        # An UNPIVOT reuses the table's alias, a PIVOT gets a generated one.
                        pivot_alias = source.alias if pivot.unpivot else next_alias_name()
                        pivot.set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias)))

                    # This case corresponds to a pivoted CTE, we don't want to qualify that
                    if isinstance(scope.sources.get(source.alias_or_name), Scope):
                        continue

                _qualify(source)

                if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
                    with csv_reader(source.this) as reader:
                        # The first row is used as the header and the second as a sample row.
                        # NOTE(review): next() raises StopIteration if the file has fewer than
                        # two rows - confirm callers tolerate that.
                        header = next(reader)
                        columns = next(reader)
                        # Each column type is recorded as the Python type name of the sampled value.
                        # NOTE(review): if csv_reader yields plain strings this is always "str" - confirm.
                        schema.add_table(
                            source,
                            {k: type(v).__name__ for k, v in zip(header, columns)},
                            match_depth=False,
                        )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                # Reuse the existing alias node if there is one, otherwise build a generated one.
                table_alias = udtf.args.get("alias") or exp.TableAlias(
                    this=exp.to_identifier(next_alias_name())
                )
                # Dialects with UNNEST_COLUMN_ONLY: an UNNEST alias without column aliases gets
                # its own identifier as the single column and is flagged as column-only.
                if (
                    isinstance(udtf, exp.Unnest)
                    and dialect.UNNEST_COLUMN_ONLY
                    and not table_alias.columns
                ):
                    table_alias.set("columns", [table_alias.this.copy()])
                    table_alias.set("column_only", True)

                udtf.set("alias", table_alias)

                # An existing alias node may carry columns but no name; give it a generated name.
                if not table_alias.name:
                    table_alias.set("this", exp.to_identifier(next_alias_name()))
                # VALUES without column aliases gets the dialect's generated column aliases.
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    column_aliases = dialect.generate_values_aliases(udtf)
                    table_alias.set("columns", column_aliases)
            else:
                # Any other kind of source: alias every still-unaliased table that sits directly
                # under a FROM or JOIN in this scope, using the table's own name.
                for node in scope.walk():
                    if (
                        isinstance(node, exp.Table)
                        and not node.alias
                        and isinstance(node.parent, (exp.From, exp.Join))
                    ):
                        # Mutates the table by attaching an alias to it
                        alias(node, node.name, copy=False, table=True)

        # Pass 3: columns that carry a db qualifier are looked up by their dotted qualifier
        # (all parts except the last) and, on a match, re-pointed at the table's alias.
        for column in scope.columns:
            if column.db:
                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))

                if table_alias:
                    # Clear the parts listed in exp.COLUMN_PARTS[1:] (presumably every qualifier
                    # part, leaving only the column name) before setting the alias as the table.
                    for p in exp.COLUMN_PARTS[1:]:
                        column.set(p, None)

                    column.set("table", table_alias.copy())

    return expression
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
>>> qualify_tables(expression).sql()
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
Arguments:
- expression: Expression to qualify
- db: Database name
- catalog: Catalog name
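- schema: A schema to populate with the table schemas inferred from READ_CSV sources; it is only used when infer_csv_schemas is true.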
- infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
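- dialect: The dialect used to parse the db and catalog names into identifiers; it also determines the dialect-specific aliasing applied to UNNEST and VALUES sources.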
Returns:
The qualified expression.