
sqlglot.optimizer.qualify_columns

   1from __future__ import annotations
   2
   3import itertools
   4import typing as t
   5
   6from sqlglot import alias, exp
   7from sqlglot.dialects.dialect import Dialect, DialectType
   8from sqlglot.errors import OptimizeError
   9from sqlglot.helper import seq_get, SingleValuedMapping
  10from sqlglot.optimizer.annotate_types import TypeAnnotator
  11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
  12from sqlglot.optimizer.simplify import simplify_parens
  13from sqlglot.schema import Schema, ensure_schema
  14
  15if t.TYPE_CHECKING:
  16    from sqlglot._typing import E
  17
  18
  19def qualify_columns(
  20    expression: exp.Expression,
  21    schema: t.Dict | Schema,
  22    expand_alias_refs: bool = True,
  23    expand_stars: bool = True,
  24    infer_schema: t.Optional[bool] = None,
  25    allow_partial_qualification: bool = False,
  26    dialect: DialectType = None,
  27) -> exp.Expression:
  28    """
  29    Rewrite sqlglot AST to have fully qualified columns.
  30
  31    Example:
  32        >>> import sqlglot
  33        >>> schema = {"tbl": {"col": "INT"}}
  34        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
  35        >>> qualify_columns(expression, schema).sql()
  36        'SELECT tbl.col AS col FROM tbl'
  37
  38    Args:
  39        expression: Expression to qualify.
  40        schema: Database schema.
  41        expand_alias_refs: Whether to expand references to aliases.
  42        expand_stars: Whether to expand star queries. This is a necessary step
  43            for most of the optimizer's rules to work; do not set to False unless you
  44            know what you're doing!
  45        infer_schema: Whether to infer the schema if missing.
  46        allow_partial_qualification: Whether to allow partial qualification.
  47
  48    Returns:
  49        The qualified expression.
  50
  51    Notes:
  52        - Currently only handles a single PIVOT or UNPIVOT operator
  53    """
  54    schema = ensure_schema(schema, dialect=dialect)
  55    annotator = TypeAnnotator(schema)
  56    infer_schema = schema.empty if infer_schema is None else infer_schema
  57    dialect = Dialect.get_or_raise(schema.dialect)
  58    pseudocolumns = dialect.PSEUDOCOLUMNS
  59    bigquery = dialect == "bigquery"
  60
  61    for scope in traverse_scope(expression):
  62        scope_expression = scope.expression
  63        is_select = isinstance(scope_expression, exp.Select)
  64
  65        if is_select and scope_expression.args.get("connect"):
  66            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
  67            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
  68            scope_expression.transform(
  69                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
  70                copy=False,
  71            )
  72            scope.clear_cache()
  73
  74        resolver = Resolver(scope, schema, infer_schema=infer_schema)
  75        _pop_table_column_aliases(scope.ctes)
  76        _pop_table_column_aliases(scope.derived_tables)
  77        using_column_tables = _expand_using(scope, resolver)
  78
  79        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
  80            _expand_alias_refs(
  81                scope,
  82                resolver,
  83                dialect,
  84                expand_only_groupby=bigquery,
  85            )
  86
  87        _convert_columns_to_dots(scope, resolver)
  88        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
  89
  90        if not schema.empty and expand_alias_refs:
  91            _expand_alias_refs(scope, resolver, dialect)
  92
  93        if is_select:
  94            if expand_stars:
  95                _expand_stars(
  96                    scope,
  97                    resolver,
  98                    using_column_tables,
  99                    pseudocolumns,
 100                    annotator,
 101                )
 102            qualify_outputs(scope)
 103
 104        _expand_group_by(scope, dialect)
 105
 106        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
 107        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
 108        _expand_order_by_and_distinct_on(scope, resolver)
 109
 110        if bigquery:
 111            annotator.annotate_scope(scope)
 112
 113    return expression
 114
 115
 116def validate_qualify_columns(expression: E) -> E:
 117    """Raise an `OptimizeError` if any columns aren't qualified"""
 118    all_unqualified_columns = []
 119    for scope in traverse_scope(expression):
 120        if isinstance(scope.expression, exp.Select):
 121            unqualified_columns = scope.unqualified_columns
 122
 123            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 124                column = scope.external_columns[0]
 125                for_table = f" for table: '{column.table}'" if column.table else ""
 126                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 127
 128            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 129                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 130                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 131                # this list here to ensure those in the former category will be excluded.
 132                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 133                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 134
 135            all_unqualified_columns.extend(unqualified_columns)
 136
 137    if all_unqualified_columns:
 138        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
 139
 140    return expression
 141
 142
 143def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
 144    name_columns = [
 145        field.this
 146        for field in unpivot.fields
 147        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
 148    ]
 149    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
 150
 151    return itertools.chain(name_columns, value_columns)
 152
 153
 154def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 155    """
 156    Remove table column aliases.
 157
 158    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 159    """
 160    for derived_table in derived_tables:
 161        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
 162            continue
 163        table_alias = derived_table.args.get("alias")
 164        if table_alias:
 165            table_alias.args.pop("columns", None)
 166
 167
 168def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
 169    columns = {}
 170
 171    def _update_source_columns(source_name: str) -> None:
 172        for column_name in resolver.get_source_columns(source_name):
 173            if column_name not in columns:
 174                columns[column_name] = source_name
 175
 176    joins = list(scope.find_all(exp.Join))
 177    names = {join.alias_or_name for join in joins}
 178    ordered = [key for key in scope.selected_sources if key not in names]
 179
 180    if names and not ordered:
 181        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")
 182
 183    # Mapping of automatically joined column names to an ordered set of source names (dict).
 184    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
 185
 186    for source_name in ordered:
 187        _update_source_columns(source_name)
 188
 189    for i, join in enumerate(joins):
 190        source_table = ordered[-1]
 191        if source_table:
 192            _update_source_columns(source_table)
 193
 194        join_table = join.alias_or_name
 195        ordered.append(join_table)
 196
 197        using = join.args.get("using")
 198        if not using:
 199            continue
 200
 201        join_columns = resolver.get_source_columns(join_table)
 202        conditions = []
 203        using_identifier_count = len(using)
 204        is_semi_or_anti_join = join.is_semi_or_anti_join
 205
 206        for identifier in using:
 207            identifier = identifier.name
 208            table = columns.get(identifier)
 209
 210            if not table or identifier not in join_columns:
 211                if (columns and "*" not in columns) and join_columns:
 212                    raise OptimizeError(f"Cannot automatically join: {identifier}")
 213
 214            table = table or source_table
 215
 216            if i == 0 or using_identifier_count == 1:
 217                lhs: exp.Expression = exp.column(identifier, table=table)
 218            else:
 219                coalesce_columns = [
 220                    exp.column(identifier, table=t)
 221                    for t in ordered[:-1]
 222                    if identifier in resolver.get_source_columns(t)
 223                ]
 224                if len(coalesce_columns) > 1:
 225                    lhs = exp.func("coalesce", *coalesce_columns)
 226                else:
 227                    lhs = exp.column(identifier, table=table)
 228
 229            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))
 230
 231            # Set all values in the dict to None, because we only care about the key ordering
 232            tables = column_tables.setdefault(identifier, {})
 233
 234            # Do not update the dict if this was a SEMI/ANTI join in
 235            # order to avoid generating COALESCE columns for this join pair
 236            if not is_semi_or_anti_join:
 237                if table not in tables:
 238                    tables[table] = None
 239                if join_table not in tables:
 240                    tables[join_table] = None
 241
 242        join.args.pop("using")
 243        join.set("on", exp.and_(*conditions, copy=False))
 244
 245    if column_tables:
 246        for column in scope.columns:
 247            if not column.table and column.name in column_tables:
 248                tables = column_tables[column.name]
 249                coalesce_args = [exp.column(column.name, table=table) for table in tables]
 250                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)
 251
 252                if isinstance(column.parent, exp.Select):
 253                    # Ensure the USING column keeps its name if it's projected
 254                    replacement = alias(replacement, alias=column.name, copy=False)
 255                elif isinstance(column.parent, exp.Struct):
 256                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
 257                    replacement = exp.PropertyEQ(
 258                        this=exp.to_identifier(column.name), expression=replacement
 259                    )
 260
 261                scope.replace(column, replacement)
 262
 263    return column_tables
 264
 265
 266def _expand_alias_refs(
 267    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
 268) -> None:
 269    """
 270    Expand references to aliases.
 271    Example:
 272        SELECT y.foo AS bar, bar * 2 AS baz FROM y
 273     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
 274    """
 275    expression = scope.expression
 276
 277    if not isinstance(expression, exp.Select) or dialect == "oracle":
 278        return
 279
 280    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
 281    projections = {s.alias_or_name for s in expression.selects}
 282
 283    def replace_columns(
 284        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
 285    ) -> None:
 286        is_group_by = isinstance(node, exp.Group)
 287        is_having = isinstance(node, exp.Having)
 288        if not node or (expand_only_groupby and not is_group_by):
 289            return
 290
 291        for column in walk_in_scope(node, prune=lambda node: node.is_star):
 292            if not isinstance(column, exp.Column):
 293                continue
 294
 295            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g.:
 296            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
 297            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, would result in FUNC(FUNC(col))
 298            # This is not required for the HAVING clause, as it can evaluate expressions using both the alias & the table columns
 299            if expand_only_groupby and is_group_by and column.parent is not node:
 300                continue
 301
 302            skip_replace = False
 303            table = resolver.get_table(column.name) if resolve_table and not column.table else None
 304            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
 305
 306            if alias_expr:
 307                skip_replace = bool(
 308                    alias_expr.find(exp.AggFunc)
 309                    and column.find_ancestor(exp.AggFunc)
 310                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
 311                )
 312
 313                # BigQuery's HAVING clause gets confused if an alias matches a source, e.g.:
 314                # SELECT x.a, max(x.b) AS x FROM x GROUP BY 1 HAVING x > 1;
 315                # If HAVING x is expanded to max(x.b), BigQuery treats x as the projection x instead of the table
 316                if is_having and dialect == "bigquery":
 317                    skip_replace = skip_replace or any(
 318                        node.parts[0].name in projections
 319                        for node in alias_expr.find_all(exp.Column)
 320                    )
 321
 322            if table and (not alias_expr or skip_replace):
 323                column.set("table", table)
 324            elif not column.table and alias_expr and not skip_replace:
 325                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
 326                    if literal_index:
 327                        column.replace(exp.Literal.number(i))
 328                else:
 329                    column = column.replace(exp.paren(alias_expr))
 330                    simplified = simplify_parens(column)
 331                    if simplified is not column:
 332                        column.replace(simplified)
 333
 334    for i, projection in enumerate(expression.selects):
 335        replace_columns(projection)
 336        if isinstance(projection, exp.Alias):
 337            alias_to_expression[projection.alias] = (projection.this, i + 1)
 338
 339    parent_scope = scope
 340    while parent_scope.is_union:
 341        parent_scope = parent_scope.parent
 342
 343    # We shouldn't expand aliases if they match the recursive CTE's columns
 344    if parent_scope.is_cte:
 345        cte = parent_scope.expression.parent
 346        if cte.find_ancestor(exp.With).recursive:
 347            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
 348                alias_to_expression.pop(recursive_cte_column.output_name, None)
 349
 350    replace_columns(expression.args.get("where"))
 351    replace_columns(expression.args.get("group"), literal_index=True)
 352    replace_columns(expression.args.get("having"), resolve_table=True)
 353    replace_columns(expression.args.get("qualify"), resolve_table=True)
 354
 355    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
 356    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
 357    if dialect == "snowflake":
 358        for join in expression.args.get("joins") or []:
 359            replace_columns(join)
 360
 361    scope.clear_cache()
 362
 363
 364def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
 365    expression = scope.expression
 366    group = expression.args.get("group")
 367    if not group:
 368        return
 369
 370    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
 371    expression.set("group", group)
 372
 373
 374def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
 375    for modifier_key in ("order", "distinct"):
 376        modifier = scope.expression.args.get(modifier_key)
 377        if isinstance(modifier, exp.Distinct):
 378            modifier = modifier.args.get("on")
 379
 380        if not isinstance(modifier, exp.Expression):
 381            continue
 382
 383        modifier_expressions = modifier.expressions
 384        if modifier_key == "order":
 385            modifier_expressions = [ordered.this for ordered in modifier_expressions]
 386
 387        for original, expanded in zip(
 388            modifier_expressions,
 389            _expand_positional_references(
 390                scope, modifier_expressions, resolver.schema.dialect, alias=True
 391            ),
 392        ):
 393            for agg in original.find_all(exp.AggFunc):
 394                for col in agg.find_all(exp.Column):
 395                    if not col.table:
 396                        col.set("table", resolver.get_table(col.name))
 397
 398            original.replace(expanded)
 399
 400        if scope.expression.args.get("group"):
 401            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
 402
 403            for expression in modifier_expressions:
 404                expression.replace(
 405                    exp.to_identifier(_select_by_pos(scope, expression).alias)
 406                    if expression.is_int
 407                    else selects.get(expression, expression)
 408                )
 409
 410
 411def _expand_positional_references(
 412    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
 413) -> t.List[exp.Expression]:
 414    new_nodes: t.List[exp.Expression] = []
 415    ambiguous_projections = None
 416
 417    for node in expressions:
 418        if node.is_int:
 419            select = _select_by_pos(scope, t.cast(exp.Literal, node))
 420
 421            if alias:
 422                new_nodes.append(exp.column(select.args["alias"].copy()))
 423            else:
 424                select = select.this
 425
 426                if dialect == "bigquery":
 427                    if ambiguous_projections is None:
 428                        # When a projection name is also a source name and it is referenced in the
 429                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
 430                        ambiguous_projections = {
 431                            s.alias_or_name
 432                            for s in scope.expression.selects
 433                            if s.alias_or_name in scope.selected_sources
 434                        }
 435
 436                    ambiguous = any(
 437                        column.parts[0].name in ambiguous_projections
 438                        for column in select.find_all(exp.Column)
 439                    )
 440                else:
 441                    ambiguous = False
 442
 443                if (
 444                    isinstance(select, exp.CONSTANTS)
 445                    or select.find(exp.Explode, exp.Unnest)
 446                    or ambiguous
 447                ):
 448                    new_nodes.append(node)
 449                else:
 450                    new_nodes.append(select.copy())
 451        else:
 452            new_nodes.append(node)
 453
 454    return new_nodes
 455
 456
 457def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
 458    try:
 459        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
 460    except IndexError:
 461        raise OptimizeError(f"Unknown output column: {node.name}")
 462
 463
 464def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
 465    """
 466    Converts `Column` instances that represent struct field lookup into chained `Dots`.
 467
 468    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
 469    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
 470    """
 471    converted = False
 472    for column in itertools.chain(scope.columns, scope.stars):
 473        if isinstance(column, exp.Dot):
 474            continue
 475
 476        column_table: t.Optional[str | exp.Identifier] = column.table
 477        if (
 478            column_table
 479            and column_table not in scope.sources
 480            and (
 481                not scope.parent
 482                or column_table not in scope.parent.sources
 483                or not scope.is_correlated_subquery
 484            )
 485        ):
 486            root, *parts = column.parts
 487
 488            if root.name in scope.sources:
 489                # The struct is already qualified, but we still need to change the AST
 490                column_table = root
 491                root, *parts = parts
 492            else:
 493                column_table = resolver.get_table(root.name)
 494
 495            if column_table:
 496                converted = True
 497                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
 498
 499    if converted:
 500        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
 501        # a `for column in scope.columns` iteration, even though they shouldn't be
 502        scope.clear_cache()
 503
 504
 505def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
 506    """Disambiguate columns, ensuring each column specifies a source"""
 507    for column in scope.columns:
 508        column_table = column.table
 509        column_name = column.name
 510
 511        if column_table and column_table in scope.sources:
 512            source_columns = resolver.get_source_columns(column_table)
 513            if (
 514                not allow_partial_qualification
 515                and source_columns
 516                and column_name not in source_columns
 517                and "*" not in source_columns
 518            ):
 519                raise OptimizeError(f"Unknown column: {column_name}")
 520
 521        if not column_table:
 522            if scope.pivots and not column.find_ancestor(exp.Pivot):
 523                # If the column is under the Pivot expression, we need to qualify it
 524                # using the name of the pivoted source instead of the pivot's alias
 525                column.set("table", exp.to_identifier(scope.pivots[0].alias))
 526                continue
 527
 528            # column_table can be an empty string, because a BigQuery UNNEST has no table alias
 529            column_table = resolver.get_table(column_name)
 530            if column_table:
 531                column.set("table", column_table)
 532            elif (
 533                resolver.schema.dialect == "bigquery"
 534                and len(column.parts) == 1
 535                and column_name in scope.selected_sources
 536            ):
 537                # BigQuery allows tables to be referenced as columns, treating them as structs
 538                scope.replace(column, exp.TableColumn(this=column.this))
 539
 540    for pivot in scope.pivots:
 541        for column in pivot.find_all(exp.Column):
 542            if not column.table and column.name in resolver.all_columns:
 543                column_table = resolver.get_table(column.name)
 544                if column_table:
 545                    column.set("table", column_table)
 546
 547
 548def _expand_struct_stars(
 549    expression: exp.Dot,
 550) -> t.List[exp.Alias]:
 551    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
 552
 553    dot_column = t.cast(exp.Column, expression.find(exp.Column))
 554    if not dot_column.is_type(exp.DataType.Type.STRUCT):
 555        return []
 556
 557    # All nested struct values are ColumnDefs, so normalize the first exp.Column into one
 558    dot_column = dot_column.copy()
 559    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)
 560
 561    # First part is the table name and last part is the star so they can be dropped
 562    dot_parts = expression.parts[1:-1]
 563
 564    # If we're expanding a nested struct, e.g. t.c.f1.f2.*, find the last struct (f2 in this case)
 565    for part in dot_parts[1:]:
 566        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
 567            # Unable to expand star unless all fields are named
 568            if not isinstance(field.this, exp.Identifier):
 569                return []
 570
 571            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
 572                starting_struct = field
 573                break
 574        else:
 575            # There is no matching field in the struct
 576            return []
 577
 578    taken_names = set()
 579    new_selections = []
 580
 581    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
 582        name = field.name
 583
 584        # Ambiguous or anonymous fields can't be expanded
 585        if name in taken_names or not isinstance(field.this, exp.Identifier):
 586            return []
 587
 588        taken_names.add(name)
 589
 590        this = field.this.copy()
 591        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
 592        new_column = exp.column(
 593            t.cast(exp.Identifier, root),
 594            table=dot_column.args.get("table"),
 595            fields=t.cast(t.List[exp.Identifier], parts),
 596        )
 597        new_selections.append(alias(new_column, this, copy=False))
 598
 599    return new_selections
 600
 601
 602def _expand_stars(
 603    scope: Scope,
 604    resolver: Resolver,
 605    using_column_tables: t.Dict[str, t.Any],
 606    pseudocolumns: t.Set[str],
 607    annotator: TypeAnnotator,
 608) -> None:
 609    """Expand stars to lists of column selections"""
 610
 611    new_selections: t.List[exp.Expression] = []
 612    except_columns: t.Dict[int, t.Set[str]] = {}
 613    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
 614    rename_columns: t.Dict[int, t.Dict[str, str]] = {}
 615
 616    coalesced_columns = set()
 617    dialect = resolver.schema.dialect
 618
 619    pivot_output_columns = None
 620    pivot_exclude_columns: t.Set[str] = set()
 621
 622    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
 623    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
 624        if pivot.unpivot:
 625            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
 626
 627            for field in pivot.fields:
 628                if isinstance(field, exp.In):
 629                    pivot_exclude_columns.update(
 630                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
 631                    )
 632
 633        else:
 634            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
 635
 636            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
 637            if not pivot_output_columns:
 638                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
 639
 640    is_bigquery = dialect == "bigquery"
 641    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
 642        # Found struct expansion, annotate scope ahead of time
 643        annotator.annotate_scope(scope)
 644
 645    for expression in scope.expression.selects:
 646        tables = []
 647        if isinstance(expression, exp.Star):
 648            tables.extend(scope.selected_sources)
 649            _add_except_columns(expression, tables, except_columns)
 650            _add_replace_columns(expression, tables, replace_columns)
 651            _add_rename_columns(expression, tables, rename_columns)
 652        elif expression.is_star:
 653            if not isinstance(expression, exp.Dot):
 654                tables.append(expression.table)
 655                _add_except_columns(expression.this, tables, except_columns)
 656                _add_replace_columns(expression.this, tables, replace_columns)
 657                _add_rename_columns(expression.this, tables, rename_columns)
 658            elif is_bigquery:
 659                struct_fields = _expand_struct_stars(expression)
 660                if struct_fields:
 661                    new_selections.extend(struct_fields)
 662                    continue
 663
 664        if not tables:
 665            new_selections.append(expression)
 666            continue
 667
 668        for table in tables:
 669            if table not in scope.sources:
 670                raise OptimizeError(f"Unknown table: {table}")
 671
 672            columns = resolver.get_source_columns(table, only_visible=True)
 673            columns = columns or scope.outer_columns
 674
 675            if pseudocolumns:
 676                columns = [name for name in columns if name.upper() not in pseudocolumns]
 677
 678            if not columns or "*" in columns:
 679                return
 680
 681            table_id = id(table)
 682            columns_to_exclude = except_columns.get(table_id) or set()
 683            renamed_columns = rename_columns.get(table_id, {})
 684            replaced_columns = replace_columns.get(table_id, {})
 685
 686            if pivot:
 687                if pivot_output_columns and pivot_exclude_columns:
 688                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
 689                    pivot_columns.extend(pivot_output_columns)
 690                else:
 691                    pivot_columns = pivot.alias_column_names
 692
 693                if pivot_columns:
 694                    new_selections.extend(
 695                        alias(exp.column(name, table=pivot.alias), name, copy=False)
 696                        for name in pivot_columns
 697                        if name not in columns_to_exclude
 698                    )
 699                    continue
 700
 701            for name in columns:
 702                if name in columns_to_exclude or name in coalesced_columns:
 703                    continue
 704                if name in using_column_tables and table in using_column_tables[name]:
 705                    coalesced_columns.add(name)
 706                    tables = using_column_tables[name]
 707                    coalesce_args = [exp.column(name, table=table) for table in tables]
 708
 709                    new_selections.append(
 710                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
 711                    )
 712                else:
 713                    alias_ = renamed_columns.get(name, name)
 714                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
 715                    new_selections.append(
 716                        alias(selection_expr, alias_, copy=False)
 717                        if alias_ != name
 718                        else selection_expr
 719                    )
 720
 721    # Ensures we don't overwrite the initial selections with an empty list
 722    if new_selections and isinstance(scope.expression, exp.Select):
 723        scope.expression.set("expressions", new_selections)
 724
 725
 726def _add_except_columns(
 727    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
 728) -> None:
 729    except_ = expression.args.get("except")
 730
 731    if not except_:
 732        return
 733
 734    columns = {e.name for e in except_}
 735
 736    for table in tables:
 737        except_columns[id(table)] = columns
 738
 739
 740def _add_rename_columns(
 741    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
 742) -> None:
 743    rename = expression.args.get("rename")
 744
 745    if not rename:
 746        return
 747
 748    columns = {e.this.name: e.alias for e in rename}
 749
 750    for table in tables:
 751        rename_columns[id(table)] = columns
 752
 753
 754def _add_replace_columns(
 755    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
 756) -> None:
 757    replace = expression.args.get("replace")
 758
 759    if not replace:
 760        return
 761
 762    columns = {e.alias: e for e in replace}
 763
 764    for table in tables:
 765        replace_columns[id(table)] = columns
 766
 767
 768def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
 769    """Ensure all output columns are aliased"""
 770    if isinstance(scope_or_expression, exp.Expression):
 771        scope = build_scope(scope_or_expression)
 772        if not isinstance(scope, Scope):
 773            return
 774    else:
 775        scope = scope_or_expression
 776
 777    new_selections = []
 778    for i, (selection, aliased_column) in enumerate(
 779        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
 780    ):
 781        if selection is None or isinstance(selection, exp.QueryTransform):
 782            break
 783
 784        if isinstance(selection, exp.Subquery):
 785            if not selection.output_name:
 786                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
 787        elif not isinstance(selection, exp.Alias) and not selection.is_star:
 788            selection = alias(
 789                selection,
 790                alias=selection.output_name or f"_col_{i}",
 791                copy=False,
 792            )
 793        if aliased_column:
 794            selection.set("alias", exp.to_identifier(aliased_column))
 795
 796        new_selections.append(selection)
 797
 798    if new_selections and isinstance(scope.expression, exp.Select):
 799        scope.expression.set("expressions", new_selections)
 800
 801
 802def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
 803    """Makes sure all identifiers that need to be quoted are quoted."""
 804    return expression.transform(
 805        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
 806    )  # type: ignore
 807
 808
 809def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
 810    """
 811    Pushes down the CTE alias columns into the projection,
 812
 813    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
 814
 815    Example:
 816        >>> import sqlglot
 817        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
 818        >>> pushdown_cte_alias_columns(expression).sql()
 819        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
 820
 821    Args:
 822        expression: Expression to pushdown.
 823
 824    Returns:
 825        The expression with the CTE aliases pushed down into the projection.
 826    """
 827    for cte in expression.find_all(exp.CTE):
 828        if cte.alias_column_names:
 829            new_expressions = []
 830            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
 831                if isinstance(projection, exp.Alias):
 832                    projection.set("alias", _alias)
 833                else:
 834                    projection = alias(projection, alias=_alias)
 835                new_expressions.append(projection)
 836            cte.this.set("expressions", new_expressions)
 837
 838    return expression
 839
 840
 841class Resolver:
 842    """
 843    Helper for resolving columns.
 844
 845    This is a class so we can lazily load some things and easily share them across functions.
 846    """
 847
 848    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
 849        self.scope = scope
 850        self.schema = schema
 851        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
 852        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
 853        self._all_columns: t.Optional[t.Set[str]] = None
 854        self._infer_schema = infer_schema
 855        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
 856
 857    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
 858        """
 859        Get the table for a column name.
 860
 861        Args:
 862            column_name: The column name to find the table for.
 863        Returns:
 864            The table name if it can be found/inferred.
 865        """
 866        if self._unambiguous_columns is None:
 867            self._unambiguous_columns = self._get_unambiguous_columns(
 868                self._get_all_source_columns()
 869            )
 870
 871        table_name = self._unambiguous_columns.get(column_name)
 872
 873        if not table_name and self._infer_schema:
 874            sources_without_schema = tuple(
 875                source
 876                for source, columns in self._get_all_source_columns().items()
 877                if not columns or "*" in columns
 878            )
 879            if len(sources_without_schema) == 1:
 880                table_name = sources_without_schema[0]
 881
 882        if table_name not in self.scope.selected_sources:
 883            return exp.to_identifier(table_name)
 884
 885        node, _ = self.scope.selected_sources.get(table_name)
 886
 887        if isinstance(node, exp.Query):
 888            while node and node.alias != table_name:
 889                node = node.parent
 890
 891        node_alias = node.args.get("alias")
 892        if node_alias:
 893            return exp.to_identifier(node_alias.this)
 894
 895        return exp.to_identifier(table_name)
 896
 897    @property
 898    def all_columns(self) -> t.Set[str]:
 899        """All available columns of all sources in this scope"""
 900        if self._all_columns is None:
 901            self._all_columns = {
 902                column for columns in self._get_all_source_columns().values() for column in columns
 903            }
 904        return self._all_columns
 905
 906    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
 907        """Resolve the source columns for a given source `name`."""
 908        cache_key = (name, only_visible)
 909        if cache_key not in self._get_source_columns_cache:
 910            if name not in self.scope.sources:
 911                raise OptimizeError(f"Unknown table: {name}")
 912
 913            source = self.scope.sources[name]
 914
 915            if isinstance(source, exp.Table):
 916                columns = self.schema.column_names(source, only_visible)
 917            elif isinstance(source, Scope) and isinstance(
 918                source.expression, (exp.Values, exp.Unnest)
 919            ):
 920                columns = source.expression.named_selects
 921
 922                # In BigQuery, UNNEST structs are automatically scoped as tables, so you can
 923                # directly select a struct field in a query.
 924                # This handles the case where the UNNEST is statically defined.
 925                if self.schema.dialect == "bigquery":
 926                    if source.expression.is_type(exp.DataType.Type.STRUCT):
 927                        for k in source.expression.type.expressions:  # type: ignore
 928                            columns.append(k.name)
 929            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
 930                set_op = source.expression
 931
 932                # BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
 933                on_column_list = set_op.args.get("on")
 934
 935                if on_column_list:
 936                    # The resulting columns are the columns in the ON clause:
 937                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
 938                    columns = [col.name for col in on_column_list]
 939                elif set_op.side or set_op.kind:
 940                    side = set_op.side
 941                    kind = set_op.kind
 942
 943                    left = set_op.left.named_selects
 944                    right = set_op.right.named_selects
 945
 946                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
 947                    if side == "LEFT":
 948                        columns = left
 949                    elif side == "FULL":
 950                        columns = list(dict.fromkeys(left + right))
 951                    elif kind == "INNER":
 952                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
 953                else:
 954                    columns = set_op.named_selects
 955            else:
 956                select = seq_get(source.expression.selects, 0)
 957
 958                if isinstance(select, exp.QueryTransform):
 959                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
 960                    schema = select.args.get("schema")
 961                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
 962                else:
 963                    columns = source.expression.named_selects
 964
 965            node, _ = self.scope.selected_sources.get(name) or (None, None)
 966            if isinstance(node, Scope):
 967                column_aliases = node.expression.alias_column_names
 968            elif isinstance(node, exp.Expression):
 969                column_aliases = node.alias_column_names
 970            else:
 971                column_aliases = []
 972
 973            if column_aliases:
 974                # If the source's columns are aliased, their aliases shadow the corresponding column names.
 975                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
 976                columns = [
 977                    alias or name
 978                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
 979                ]
 980
 981            self._get_source_columns_cache[cache_key] = columns
 982
 983        return self._get_source_columns_cache[cache_key]
 984
 985    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
 986        if self._source_columns is None:
 987            self._source_columns = {
 988                source_name: self.get_source_columns(source_name)
 989                for source_name, source in itertools.chain(
 990                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
 991                )
 992            }
 993        return self._source_columns
 994
 995    def _get_unambiguous_columns(
 996        self, source_columns: t.Dict[str, t.Sequence[str]]
 997    ) -> t.Mapping[str, str]:
 998        """
 999        Find all the unambiguous columns in sources.
1000
1001        Args:
1002            source_columns: Mapping of names to source columns.
1003
1004        Returns:
1005            Mapping of column name to source name.
1006        """
1007        if not source_columns:
1008            return {}
1009
1010        source_columns_pairs = list(source_columns.items())
1011
1012        first_table, first_columns = source_columns_pairs[0]
1013
1014        if len(source_columns_pairs) == 1:
1015            # Performance optimization - avoid copying first_columns if there is only one table.
1016            return SingleValuedMapping(first_columns, first_table)
1017
1018        unambiguous_columns = {col: first_table for col in first_columns}
1019        all_columns = set(unambiguous_columns)
1020
1021        for table, columns in source_columns_pairs[1:]:
1022            unique = set(columns)
1023            ambiguous = all_columns.intersection(unique)
1024            all_columns.update(columns)
1025
1026            for column in ambiguous:
1027                unambiguous_columns.pop(column, None)
1028            for column in unique.difference(ambiguous):
1029                unambiguous_columns[column] = table
1030
1031        return unambiguous_columns
def qualify_columns(expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:
  The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
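For a sense of how star expansion interacts with qualification, here is a hedged sketch (the table t and its columns are made up for illustration; expanded columns follow the schema's order):

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import qualify_columns
>>> schema = {"t": {"a": "INT", "b": "INT"}}
>>> qualify_columns(sqlglot.parse_one("SELECT * FROM t"), schema).sql()
'SELECT t.a AS a, t.b AS b FROM t'
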
def validate_qualify_columns(expression: ~E) -> ~E:

Raise an OptimizeError if any columns aren't qualified
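A hedged sketch of the intended flow, reusing the schema from the qualify_columns example: validation is a no-op on a fully qualified tree, and raises an OptimizeError (e.g. "Ambiguous columns: ...") if any column still lacks a table.

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
>>> qualified = qualify_columns(sqlglot.parse_one("SELECT col FROM tbl"), {"tbl": {"col": "INT"}})
>>> validate_qualify_columns(qualified).sql()
'SELECT tbl.col AS col FROM tbl'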

def qualify_outputs(scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:

Ensure all output columns are aliased
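A hedged sketch (the query is invented): named projections keep their output name, while anonymous expressions receive positional _col_{i} aliases; the expression is modified in place.

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import qualify_outputs
>>> expression = sqlglot.parse_one("SELECT x.a, 1 + 1 FROM x")
>>> qualify_outputs(expression)
>>> expression.sql()
'SELECT x.a AS a, 1 + 1 AS _col_1 FROM x'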

def quote_identifiers(expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:

Makes sure all identifiers that need to be quoted are quoted.
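A hedged sketch, assuming the default identify=True (quote every identifier) and a dialect that quotes with double quotes, such as Postgres:

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import quote_identifiers
>>> quote_identifiers(sqlglot.parse_one("SELECT a FROM t"), dialect="postgres").sql(dialect="postgres")
'SELECT "a" FROM "t"'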

def pushdown_cte_alias_columns(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:
  The expression with the CTE aliases pushed down into the projection.

class Resolver:
 842class Resolver:
 843    """
 844    Helper for resolving columns.
 845
 846    This is a class so we can lazily load some things and easily share them across functions.
 847    """
 848
 849    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
 850        self.scope = scope
 851        self.schema = schema
 852        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
 853        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
 854        self._all_columns: t.Optional[t.Set[str]] = None
 855        self._infer_schema = infer_schema
 856        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
 857
 858    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
 859        """
 860        Get the table for a column name.
 861
 862        Args:
 863            column_name: The column name to find the table for.
 864        Returns:
 865            The table name if it can be found/inferred.
 866        """
 867        if self._unambiguous_columns is None:
 868            self._unambiguous_columns = self._get_unambiguous_columns(
 869                self._get_all_source_columns()
 870            )
 871
 872        table_name = self._unambiguous_columns.get(column_name)
 873
 874        if not table_name and self._infer_schema:
 875            sources_without_schema = tuple(
 876                source
 877                for source, columns in self._get_all_source_columns().items()
 878                if not columns or "*" in columns
 879            )
 880            if len(sources_without_schema) == 1:
 881                table_name = sources_without_schema[0]
 882
 883        if table_name not in self.scope.selected_sources:
 884            return exp.to_identifier(table_name)
 885
 886        node, _ = self.scope.selected_sources.get(table_name)
 887
 888        if isinstance(node, exp.Query):
 889            while node and node.alias != table_name:
 890                node = node.parent
 891
 892        node_alias = node.args.get("alias")
 893        if node_alias:
 894            return exp.to_identifier(node_alias.this)
 895
 896        return exp.to_identifier(table_name)
 897
 898    @property
 899    def all_columns(self) -> t.Set[str]:
 900        """All available columns of all sources in this scope"""
 901        if self._all_columns is None:
 902            self._all_columns = {
 903                column for columns in self._get_all_source_columns().values() for column in columns
 904            }
 905        return self._all_columns
 906
 907    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
 908        """Resolve the source columns for a given source `name`."""
 909        cache_key = (name, only_visible)
 910        if cache_key not in self._get_source_columns_cache:
 911            if name not in self.scope.sources:
 912                raise OptimizeError(f"Unknown table: {name}")
 913
 914            source = self.scope.sources[name]
 915
 916            if isinstance(source, exp.Table):
 917                columns = self.schema.column_names(source, only_visible)
 918            elif isinstance(source, Scope) and isinstance(
 919                source.expression, (exp.Values, exp.Unnest)
 920            ):
 921                columns = source.expression.named_selects
 922
 923                # in bigquery, unnest structs are automatically scoped as tables, so you can
 924                # directly select a struct field in a query.
 925                # this handles the case where the unnest is statically defined.
 926                if self.schema.dialect == "bigquery":
 927                    if source.expression.is_type(exp.DataType.Type.STRUCT):
 928                        for k in source.expression.type.expressions:  # type: ignore
 929                            columns.append(k.name)
 930            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
 931                set_op = source.expression
 932
 933                # BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
 934                on_column_list = set_op.args.get("on")
 935
 936                if on_column_list:
 937                    # The resulting columns are the columns in the ON clause:
 938                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
 939                    columns = [col.name for col in on_column_list]
 940                elif set_op.side or set_op.kind:
 941                    side = set_op.side
 942                    kind = set_op.kind
 943
 944                    left = set_op.left.named_selects
 945                    right = set_op.right.named_selects
 946
 947                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
 948                    if side == "LEFT":
 949                        columns = left
 950                    elif side == "FULL":
 951                        columns = list(dict.fromkeys(left + right))
 952                    elif kind == "INNER":
 953                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
 954                else:
 955                    columns = set_op.named_selects
 956            else:
 957                select = seq_get(source.expression.selects, 0)
 958
 959                if isinstance(select, exp.QueryTransform):
 960                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
 961                    schema = select.args.get("schema")
 962                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
 963                else:
 964                    columns = source.expression.named_selects
 965
 966            node, _ = self.scope.selected_sources.get(name) or (None, None)
 967            if isinstance(node, Scope):
 968                column_aliases = node.expression.alias_column_names
 969            elif isinstance(node, exp.Expression):
 970                column_aliases = node.alias_column_names
 971            else:
 972                column_aliases = []
 973
 974            if column_aliases:
 975                # If the source's columns are aliased, their aliases shadow the corresponding column names.
 976                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
 977                columns = [
 978                    alias or name
 979                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
 980                ]
 981
 982            self._get_source_columns_cache[cache_key] = columns
 983
 984        return self._get_source_columns_cache[cache_key]
 985
 986    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
 987        if self._source_columns is None:
 988            self._source_columns = {
 989                source_name: self.get_source_columns(source_name)
 990                for source_name, source in itertools.chain(
 991                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
 992                )
 993            }
 994        return self._source_columns
 995
 996    def _get_unambiguous_columns(
 997        self, source_columns: t.Dict[str, t.Sequence[str]]
 998    ) -> t.Mapping[str, str]:
 999        """
1000        Find all the unambiguous columns in sources.
1001
1002        Args:
1003            source_columns: Mapping of names to source columns.
1004
1005        Returns:
1006            Mapping of column name to source name.
1007        """
1008        if not source_columns:
1009            return {}
1010
1011        source_columns_pairs = list(source_columns.items())
1012
1013        first_table, first_columns = source_columns_pairs[0]
1014
1015        if len(source_columns_pairs) == 1:
1016            # Performance optimization - avoid copying first_columns if there is only one table.
1017            return SingleValuedMapping(first_columns, first_table)
1018
1019        unambiguous_columns = {col: first_table for col in first_columns}
1020        all_columns = set(unambiguous_columns)
1021
1022        for table, columns in source_columns_pairs[1:]:
1023            unique = set(columns)
1024            ambiguous = all_columns.intersection(unique)
1025            all_columns.update(columns)
1026
1027            for column in ambiguous:
1028                unambiguous_columns.pop(column, None)
1029            for column in unique.difference(ambiguous):
1030                unambiguous_columns[column] = table
1031
1032        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.
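A minimal construction sketch (the table t and its schema are illustrative):

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.schema import ensure_schema
>>> scope = build_scope(sqlglot.parse_one("SELECT a, b FROM t"))
>>> resolver = Resolver(scope, ensure_schema({"t": {"a": "INT", "b": "INT"}}))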

Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.
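Continuing the construction sketch above, the lone source t is inferred for the unambiguous column a; a column that appears in more than one source is dropped from the unambiguous mapping, so no identifier is returned for it:

>>> resolver.get_table("a").name
't'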

all_columns: Set[str]

All available columns of all sources in this scope
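For the sketch above, this is the union of the columns of every source:

>>> sorted(resolver.all_columns)
['a', 'b']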

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:

Resolve the source columns for a given source `name`.
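Continuing the sketch above: for a plain exp.Table source the columns come straight from the schema, while Scope sources (derived tables, set operations, VALUES/UNNEST) fall through the branches shown in the listing:

>>> resolver.get_source_columns("t")
['a', 'b']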