sqlglot.optimizer.qualify_tables
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_tables(
    expression: E,
    db: t.Optional[str | exp.Identifier] = None,
    catalog: t.Optional[str | exp.Identifier] = None,
    schema: t.Optional[Schema] = None,
    infer_csv_schemas: bool = False,
    dialect: DialectType = None,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

    Besides filling in missing database/catalog qualifiers, this attaches aliases to
    tables, derived tables, CTEs, pivots and table-valued sources (UDTFs / VALUES)
    that lack one, and rewrites database-qualified column references (e.g.
    ``db.tbl.col``) to use the alias of the matching table. Only tables named by a
    plain identifier are qualified (e.g. a READ_CSV table is not). The expression is
    modified in place.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
        >>> qualify_tables(expression).sql()
        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'

    Args:
        expression: Expression to qualify. It is mutated in place and also returned.
        db: Database name used to qualify tables that have none.
        catalog: Catalog name used to qualify tables that have none; it is only applied
            to tables that have a database (either their own or ``db``).
        schema: A schema to populate with the columns of READ_CSV tables; only used
            when ``infer_csv_schemas`` is set.
        infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
        dialect: The dialect used to parse the ``db`` and ``catalog`` names.

    Returns:
        The qualified expression.
    """
    # Generated aliases are drawn from this sequence; the order of the calls below
    # determines which generated name each node receives.
    next_alias_name = name_sequence("_q_")
    db = exp.parse_identifier(db, dialect=dialect) if db else None
    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None

    def _qualify(table: exp.Table) -> None:
        # Only plain named tables are qualified (not e.g. READ_CSV(...) sources).
        if isinstance(table.this, exp.Identifier):
            if not table.args.get("db"):
                table.set("db", db)
            # A catalog without a database would be meaningless, so it is only added
            # once the table has a db (its own or the default one set just above).
            if not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog)

    # Non-query roots: qualify their tables directly, stopping at nested queries
    # (presumably because those are handled by the scope traversal below).
    if (db or catalog) and not isinstance(expression, exp.Query):
        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
            if isinstance(node, exp.Table):
                _qualify(node)

    for scope in traverse_scope(expression):
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                # `(t1 JOIN t2) AS t`: the subquery wraps a bare table carrying joins;
                # turn it into `SELECT * FROM t1 JOIN ...` so it is a real query.
                if isinstance(unnested, exp.Table):
                    joins = unnested.args.pop("joins", None)
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
                    derived_table.this.set("joins", joins)

            if not derived_table.args.get("alias"):
                alias_ = next_alias_name()
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                # Keep the scope's source mapping in sync with the new alias.
                scope.rename_source(None, alias_)

            pivots = derived_table.args.get("pivots")
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))

        # Maps a table's dotted name (as written in the query) to its alias identifier;
        # used below to rewrite db-qualified column references.
        table_aliases = {}

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                pivots = source.args.get("pivots")
                if not source.alias:
                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
                    if pivots and pivots[0].alias == name:
                        name = source.name

                    # Mutates the source by attaching an alias to it
                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)

                # Computed before _qualify() runs, so the key reflects the table's
                # parts as written, without the default db/catalog.
                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
                    source.alias
                )

                if pivots:
                    if not pivots[0].alias:
                        pivot_alias = next_alias_name()
                        pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias)))

                    # This case corresponds to a pivoted CTE, we don't want to qualify that
                    if isinstance(scope.sources.get(source.alias_or_name), Scope):
                        continue

                _qualify(source)

                if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
                    # The first row is the header; the Python type names of the first
                    # data row's values become the column types.
                    with csv_reader(source.this) as reader:
                        header = next(reader)
                        columns = next(reader)
                        schema.add_table(
                            source,
                            {k: type(v).__name__ for k, v in zip(header, columns)},
                            match_depth=False,
                        )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                table_alias = udtf.args.get("alias") or exp.TableAlias(
                    this=exp.to_identifier(next_alias_name())
                )
                udtf.set("alias", table_alias)

                # An existing alias may list columns but have no name of its own.
                if not table_alias.name:
                    table_alias.set("this", exp.to_identifier(next_alias_name()))
                # VALUES without column names: name them _col_0, _col_1, ... based on
                # the width of the first row.
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    for i, _ in enumerate(udtf.expressions[0].expressions):
                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
            else:
                # Any other source kind: alias every still-unaliased table that sits
                # directly under a FROM/JOIN among the nodes yielded by scope.walk().
                for node in scope.walk():
                    if (
                        isinstance(node, exp.Table)
                        and not node.alias
                        and isinstance(node.parent, (exp.From, exp.Join))
                    ):
                        # Mutates the table by attaching an alias to it
                        alias(node, node.name, copy=False, table=True)

        # Columns such as `db.tbl.col`: if the qualifier (everything but the column
        # name) matches a source table's name, replace the qualifier parts with that
        # table's alias.
        for column in scope.columns:
            if column.db:
                source_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))

                if source_alias:
                    for p in exp.COLUMN_PARTS[1:]:
                        column.set(p, None)
                    column.set("table", source_alias)

    return expression
def
qualify_tables( expression: ~E, db: Union[sqlglot.expressions.Identifier, str, NoneType] = None, catalog: Union[sqlglot.expressions.Identifier, str, NoneType] = None, schema: Optional[sqlglot.schema.Schema] = None, infer_csv_schemas: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> ~E:
def qualify_tables(
    expression: E,
    db: t.Optional[str | exp.Identifier] = None,
    catalog: t.Optional[str | exp.Identifier] = None,
    schema: t.Optional[Schema] = None,
    infer_csv_schemas: bool = False,
    dialect: DialectType = None,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

    Besides filling in missing database/catalog qualifiers, this attaches aliases to
    tables, derived tables, CTEs, pivots and table-valued sources (UDTFs / VALUES)
    that lack one, and rewrites database-qualified column references (e.g.
    ``db.tbl.col``) to use the alias of the matching table. Only tables named by a
    plain identifier are qualified (e.g. a READ_CSV table is not). The expression is
    modified in place.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
        >>> qualify_tables(expression).sql()
        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'

    Args:
        expression: Expression to qualify. It is mutated in place and also returned.
        db: Database name used to qualify tables that have none.
        catalog: Catalog name used to qualify tables that have none; it is only applied
            to tables that have a database (either their own or ``db``).
        schema: A schema to populate with the columns of READ_CSV tables; only used
            when ``infer_csv_schemas`` is set.
        infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
        dialect: The dialect used to parse the ``db`` and ``catalog`` names.

    Returns:
        The qualified expression.
    """
    # Generated aliases are drawn from this sequence; the order of the calls below
    # determines which generated name each node receives.
    next_alias_name = name_sequence("_q_")
    db = exp.parse_identifier(db, dialect=dialect) if db else None
    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None

    def _qualify(table: exp.Table) -> None:
        # Only plain named tables are qualified (not e.g. READ_CSV(...) sources).
        if isinstance(table.this, exp.Identifier):
            if not table.args.get("db"):
                table.set("db", db)
            # A catalog without a database would be meaningless, so it is only added
            # once the table has a db (its own or the default one set just above).
            if not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog)

    # Non-query roots: qualify their tables directly, stopping at nested queries
    # (presumably because those are handled by the scope traversal below).
    if (db or catalog) and not isinstance(expression, exp.Query):
        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
            if isinstance(node, exp.Table):
                _qualify(node)

    for scope in traverse_scope(expression):
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                # `(t1 JOIN t2) AS t`: the subquery wraps a bare table carrying joins;
                # turn it into `SELECT * FROM t1 JOIN ...` so it is a real query.
                if isinstance(unnested, exp.Table):
                    joins = unnested.args.pop("joins", None)
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
                    derived_table.this.set("joins", joins)

            if not derived_table.args.get("alias"):
                alias_ = next_alias_name()
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                # Keep the scope's source mapping in sync with the new alias.
                scope.rename_source(None, alias_)

            pivots = derived_table.args.get("pivots")
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))

        # Maps a table's dotted name (as written in the query) to its alias identifier;
        # used below to rewrite db-qualified column references.
        table_aliases = {}

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                pivots = source.args.get("pivots")
                if not source.alias:
                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
                    if pivots and pivots[0].alias == name:
                        name = source.name

                    # Mutates the source by attaching an alias to it
                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)

                # Computed before _qualify() runs, so the key reflects the table's
                # parts as written, without the default db/catalog.
                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
                    source.alias
                )

                if pivots:
                    if not pivots[0].alias:
                        pivot_alias = next_alias_name()
                        pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias)))

                    # This case corresponds to a pivoted CTE, we don't want to qualify that
                    if isinstance(scope.sources.get(source.alias_or_name), Scope):
                        continue

                _qualify(source)

                if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
                    # The first row is the header; the Python type names of the first
                    # data row's values become the column types.
                    with csv_reader(source.this) as reader:
                        header = next(reader)
                        columns = next(reader)
                        schema.add_table(
                            source,
                            {k: type(v).__name__ for k, v in zip(header, columns)},
                            match_depth=False,
                        )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                table_alias = udtf.args.get("alias") or exp.TableAlias(
                    this=exp.to_identifier(next_alias_name())
                )
                udtf.set("alias", table_alias)

                # An existing alias may list columns but have no name of its own.
                if not table_alias.name:
                    table_alias.set("this", exp.to_identifier(next_alias_name()))
                # VALUES without column names: name them _col_0, _col_1, ... based on
                # the width of the first row.
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    for i, _ in enumerate(udtf.expressions[0].expressions):
                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
            else:
                # Any other source kind: alias every still-unaliased table that sits
                # directly under a FROM/JOIN among the nodes yielded by scope.walk().
                for node in scope.walk():
                    if (
                        isinstance(node, exp.Table)
                        and not node.alias
                        and isinstance(node.parent, (exp.From, exp.Join))
                    ):
                        # Mutates the table by attaching an alias to it
                        alias(node, node.name, copy=False, table=True)

        # Columns such as `db.tbl.col`: if the qualifier (everything but the column
        # name) matches a source table's name, replace the qualifier parts with that
        # table's alias.
        for column in scope.columns:
            if column.db:
                source_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))

                if source_alias:
                    for p in exp.COLUMN_PARTS[1:]:
                        column.set(p, None)
                    column.set("table", source_alias)

    return expression
Rewrite a sqlglot AST so that its tables are fully qualified (and aliased). Join constructs such as `(t1 JOIN t2) AS t` are expanded into `(SELECT * FROM t1 AS t1, t2 AS t2) AS t`.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
>>> qualify_tables(expression).sql()
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
Arguments:
- expression: Expression to qualify
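- db: Database name used to qualify tables that have none.
- catalog: Catalog name used to qualify tables that have none; it is only applied to tables that have a database (either their own or db).
- schema: A schema to populate with the columns of READ_CSV tables; only used when infer_csv_schemas is set.
- infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
- dialect: The dialect used to parse the db and catalog names.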
Returns:
The qualified expression.