Edit on GitHub

sqlglot.optimizer.qualify_tables

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import DialectType
  8from sqlglot.helper import csv_reader, name_sequence
  9from sqlglot.optimizer.scope import Scope, traverse_scope
 10from sqlglot.schema import Schema
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot._typing import E
 14
 15
 16def qualify_tables(
 17    expression: E,
 18    db: t.Optional[str | exp.Identifier] = None,
 19    catalog: t.Optional[str | exp.Identifier] = None,
 20    schema: t.Optional[Schema] = None,
 21    infer_csv_schemas: bool = False,
 22    dialect: DialectType = None,
 23) -> E:
 24    """
 25    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
 26    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
 27
 28    Examples:
 29        >>> import sqlglot
 30        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
 31        >>> qualify_tables(expression, db="db").sql()
 32        'SELECT 1 FROM db.tbl AS tbl'
 33        >>>
 34        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
 35        >>> qualify_tables(expression).sql()
 36        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
 37
 38    Args:
 39        expression: Expression to qualify
 40        db: Database name
 41        catalog: Catalog name
 42        schema: A schema to populate
 43        infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
 44        dialect: The dialect to parse catalog and schema into.
 45
 46    Returns:
 47        The qualified expression.
 48    """
 49    next_alias_name = name_sequence("_q_")
 50    db = exp.parse_identifier(db, dialect=dialect) if db else None
 51    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
 52
 53    def _qualify(table: exp.Table) -> None:
 54        if isinstance(table.this, exp.Identifier):
 55            if not table.args.get("db"):
 56                table.set("db", db)
 57            if not table.args.get("catalog") and table.args.get("db"):
 58                table.set("catalog", catalog)
 59
 60    if (db or catalog) and not isinstance(expression, exp.Query):
 61        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
 62            if isinstance(node, exp.Table):
 63                _qualify(node)
 64
 65    for scope in traverse_scope(expression):
 66        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
 67            if isinstance(derived_table, exp.Subquery):
 68                unnested = derived_table.unnest()
 69                if isinstance(unnested, exp.Table):
 70                    joins = unnested.args.pop("joins", None)
 71                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
 72                    derived_table.this.set("joins", joins)
 73
 74            if not derived_table.args.get("alias"):
 75                alias_ = next_alias_name()
 76                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
 77                scope.rename_source(None, alias_)
 78
 79            pivots = derived_table.args.get("pivots")
 80            if pivots and not pivots[0].alias:
 81                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
 82
 83        table_aliases = {}
 84
 85        for name, source in scope.sources.items():
 86            if isinstance(source, exp.Table):
 87                pivots = source.args.get("pivots")
 88                if not source.alias:
 89                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
 90                    if pivots and pivots[0].alias == name:
 91                        name = source.name
 92
 93                    # Mutates the source by attaching an alias to it
 94                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)
 95
 96                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
 97                    source.alias
 98                )
 99
100                _qualify(source)
101
102                if pivots and not pivots[0].alias:
103                    pivots[0].set(
104                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
105                    )
106
107                if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
108                    with csv_reader(source.this) as reader:
109                        header = next(reader)
110                        columns = next(reader)
111                        schema.add_table(
112                            source,
113                            {k: type(v).__name__ for k, v in zip(header, columns)},
114                            match_depth=False,
115                        )
116            elif isinstance(source, Scope) and source.is_udtf:
117                udtf = source.expression
118                table_alias = udtf.args.get("alias") or exp.TableAlias(
119                    this=exp.to_identifier(next_alias_name())
120                )
121                udtf.set("alias", table_alias)
122
123                if not table_alias.name:
124                    table_alias.set("this", exp.to_identifier(next_alias_name()))
125                if isinstance(udtf, exp.Values) and not table_alias.columns:
126                    for i, e in enumerate(udtf.expressions[0].expressions):
127                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
128            else:
129                for node in scope.walk():
130                    if (
131                        isinstance(node, exp.Table)
132                        and not node.alias
133                        and isinstance(node.parent, (exp.From, exp.Join))
134                    ):
135                        # Mutates the table by attaching an alias to it
136                        alias(node, node.name, copy=False, table=True)
137
138        for column in scope.columns:
139            if column.db:
140                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))
141
142                if table_alias:
143                    for p in exp.COLUMN_PARTS[1:]:
144                        column.set(p, None)
145                    column.set("table", table_alias)
146
147    return expression
def qualify_tables( expression: ~E, db: Union[sqlglot.expressions.Identifier, str, NoneType] = None, catalog: Union[sqlglot.expressions.Identifier, str, NoneType] = None, schema: Optional[sqlglot.schema.Schema] = None, infer_csv_schemas: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> ~E:
 17def qualify_tables(
 18    expression: E,
 19    db: t.Optional[str | exp.Identifier] = None,
 20    catalog: t.Optional[str | exp.Identifier] = None,
 21    schema: t.Optional[Schema] = None,
 22    infer_csv_schemas: bool = False,
 23    dialect: DialectType = None,
 24) -> E:
 25    """
 26    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
 27    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
 28
 29    Examples:
 30        >>> import sqlglot
 31        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
 32        >>> qualify_tables(expression, db="db").sql()
 33        'SELECT 1 FROM db.tbl AS tbl'
 34        >>>
 35        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
 36        >>> qualify_tables(expression).sql()
 37        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
 38
 39    Args:
 40        expression: Expression to qualify
 41        db: Database name
 42        catalog: Catalog name
 43        schema: A schema to populate
 44        infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
 45        dialect: The dialect to parse catalog and schema into.
 46
 47    Returns:
 48        The qualified expression.
 49    """
 50    next_alias_name = name_sequence("_q_")
 51    db = exp.parse_identifier(db, dialect=dialect) if db else None
 52    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
 53
 54    def _qualify(table: exp.Table) -> None:
 55        if isinstance(table.this, exp.Identifier):
 56            if not table.args.get("db"):
 57                table.set("db", db)
 58            if not table.args.get("catalog") and table.args.get("db"):
 59                table.set("catalog", catalog)
 60
 61    if (db or catalog) and not isinstance(expression, exp.Query):
 62        for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
 63            if isinstance(node, exp.Table):
 64                _qualify(node)
 65
 66    for scope in traverse_scope(expression):
 67        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
 68            if isinstance(derived_table, exp.Subquery):
 69                unnested = derived_table.unnest()
 70                if isinstance(unnested, exp.Table):
 71                    joins = unnested.args.pop("joins", None)
 72                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
 73                    derived_table.this.set("joins", joins)
 74
 75            if not derived_table.args.get("alias"):
 76                alias_ = next_alias_name()
 77                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
 78                scope.rename_source(None, alias_)
 79
 80            pivots = derived_table.args.get("pivots")
 81            if pivots and not pivots[0].alias:
 82                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
 83
 84        table_aliases = {}
 85
 86        for name, source in scope.sources.items():
 87            if isinstance(source, exp.Table):
 88                pivots = source.args.get("pivots")
 89                if not source.alias:
 90                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
 91                    if pivots and pivots[0].alias == name:
 92                        name = source.name
 93
 94                    # Mutates the source by attaching an alias to it
 95                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)
 96
 97                table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
 98                    source.alias
 99                )
100
101                _qualify(source)
102
103                if pivots and not pivots[0].alias:
104                    pivots[0].set(
105                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
106                    )
107
108                if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
109                    with csv_reader(source.this) as reader:
110                        header = next(reader)
111                        columns = next(reader)
112                        schema.add_table(
113                            source,
114                            {k: type(v).__name__ for k, v in zip(header, columns)},
115                            match_depth=False,
116                        )
117            elif isinstance(source, Scope) and source.is_udtf:
118                udtf = source.expression
119                table_alias = udtf.args.get("alias") or exp.TableAlias(
120                    this=exp.to_identifier(next_alias_name())
121                )
122                udtf.set("alias", table_alias)
123
124                if not table_alias.name:
125                    table_alias.set("this", exp.to_identifier(next_alias_name()))
126                if isinstance(udtf, exp.Values) and not table_alias.columns:
127                    for i, e in enumerate(udtf.expressions[0].expressions):
128                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
129            else:
130                for node in scope.walk():
131                    if (
132                        isinstance(node, exp.Table)
133                        and not node.alias
134                        and isinstance(node.parent, (exp.From, exp.Join))
135                    ):
136                        # Mutates the table by attaching an alias to it
137                        alias(node, node.name, copy=False, table=True)
138
139        for column in scope.columns:
140            if column.db:
141                table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1]))
142
143                if table_alias:
144                    for p in exp.COLUMN_PARTS[1:]:
145                        column.set(p, None)
146                    column.set("table", table_alias)
147
148    return expression

Rewrite sqlglot AST to have fully qualified tables. Join constructs such as (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
>>> qualify_tables(expression).sql()
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
Arguments:
  • expression: Expression to qualify
  • db: Database name
  • catalog: Catalog name
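  • schema: A schema to populate with the table schemas inferred from READ_CSV sources (only used when infer_csv_schemas is set).
  • infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
  • dialect: The dialect used to parse the db and catalog identifiers.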
Returns:

The qualified expression.