
sqlglot.optimizer.qualify_columns

   1from __future__ import annotations
   2
   3import itertools
   4import typing as t
   5
   6from sqlglot import alias, exp
   7from sqlglot.dialects.dialect import Dialect, DialectType
   8from sqlglot.errors import OptimizeError
   9from sqlglot.helper import seq_get, SingleValuedMapping
  10from sqlglot.optimizer.annotate_types import TypeAnnotator
  11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
  12from sqlglot.optimizer.simplify import simplify_parens
  13from sqlglot.schema import Schema, ensure_schema
  14
  15if t.TYPE_CHECKING:
  16    from sqlglot._typing import E
  17
  18
  19def qualify_columns(
  20    expression: exp.Expression,
  21    schema: t.Dict | Schema,
  22    expand_alias_refs: bool = True,
  23    expand_stars: bool = True,
  24    infer_schema: t.Optional[bool] = None,
  25    allow_partial_qualification: bool = False,
  26    dialect: DialectType = None,
  27) -> exp.Expression:
  28    """
  29    Rewrite sqlglot AST to have fully qualified columns.
  30
  31    Example:
  32        >>> import sqlglot
  33        >>> schema = {"tbl": {"col": "INT"}}
  34        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
  35        >>> qualify_columns(expression, schema).sql()
  36        'SELECT tbl.col AS col FROM tbl'
  37
  38    Args:
  39        expression: Expression to qualify.
  40        schema: Database schema.
  41        expand_alias_refs: Whether to expand references to aliases.
  42        expand_stars: Whether to expand star queries. This is a necessary step
  43            for most of the optimizer's rules to work; do not set to False unless you
  44            know what you're doing!
  45        infer_schema: Whether to infer the schema if missing.
  46        allow_partial_qualification: Whether to allow partial qualification.
  47
  48    Returns:
  49        The qualified expression.
  50
  51    Notes:
  52        - Currently only handles a single PIVOT or UNPIVOT operator
  53    """
  54    schema = ensure_schema(schema, dialect=dialect)
  55    annotator = TypeAnnotator(schema)
  56    infer_schema = schema.empty if infer_schema is None else infer_schema
  57    dialect = Dialect.get_or_raise(schema.dialect)
  58    pseudocolumns = dialect.PSEUDOCOLUMNS
  59    bigquery = dialect == "bigquery"
  60
  61    for scope in traverse_scope(expression):
  62        scope_expression = scope.expression
  63        is_select = isinstance(scope_expression, exp.Select)
  64
  65        if is_select and scope_expression.args.get("connect"):
  66            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
  67            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
  68            scope_expression.transform(
  69                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
  70                copy=False,
  71            )
  72            scope.clear_cache()
  73
  74        resolver = Resolver(scope, schema, infer_schema=infer_schema)
  75        _pop_table_column_aliases(scope.ctes)
  76        _pop_table_column_aliases(scope.derived_tables)
  77        using_column_tables = _expand_using(scope, resolver)
  78
  79        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
  80            _expand_alias_refs(
  81                scope,
  82                resolver,
  83                dialect,
  84                expand_only_groupby=bigquery,
  85            )
  86
  87        _convert_columns_to_dots(scope, resolver)
  88        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
  89
  90        if not schema.empty and expand_alias_refs:
  91            _expand_alias_refs(scope, resolver, dialect)
  92
  93        if is_select:
  94            if expand_stars:
  95                _expand_stars(
  96                    scope,
  97                    resolver,
  98                    using_column_tables,
  99                    pseudocolumns,
 100                    annotator,
 101                )
 102            qualify_outputs(scope)
 103
 104        _expand_group_by(scope, dialect)
 105
 106        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
 107        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
 108        _expand_order_by_and_distinct_on(scope, resolver)
 109
 110        if bigquery:
 111            annotator.annotate_scope(scope)
 112
 113    return expression
 114
 115
 116def validate_qualify_columns(expression: E) -> E:
 117    """Raise an `OptimizeError` if any columns aren't qualified"""
 118    all_unqualified_columns = []
 119    for scope in traverse_scope(expression):
 120        if isinstance(scope.expression, exp.Select):
 121            unqualified_columns = scope.unqualified_columns
 122
 123            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 124                column = scope.external_columns[0]
 125                for_table = f" for table: '{column.table}'" if column.table else ""
 126                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 127
 128            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 129                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 130                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 131                # this list here to ensure those in the former category will be excluded.
 132                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 133                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 134
 135            all_unqualified_columns.extend(unqualified_columns)
 136
 137    if all_unqualified_columns:
 138        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
 139
 140    return expression
 141
 142
 143def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
 144    name_columns = [
 145        field.this
 146        for field in unpivot.fields
 147        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
 148    ]
 149    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
 150
 151    return itertools.chain(name_columns, value_columns)
 152
 153
 154def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 155    """
 156    Remove table column aliases.
 157
 158    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 159    """
 160    for derived_table in derived_tables:
 161        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
 162            continue
 163        table_alias = derived_table.args.get("alias")
 164        if table_alias:
 165            table_alias.args.pop("columns", None)
 166
 167
 168def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
 169    columns = {}
 170
 171    def _update_source_columns(source_name: str) -> None:
 172        for column_name in resolver.get_source_columns(source_name):
 173            if column_name not in columns:
 174                columns[column_name] = source_name
 175
 176    joins = list(scope.find_all(exp.Join))
 177    names = {join.alias_or_name for join in joins}
 178    ordered = [key for key in scope.selected_sources if key not in names]
 179
 180    if names and not ordered:
 181        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")
 182
 183    # Mapping of automatically joined column names to an ordered set of source names (dict).
 184    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
 185
 186    for source_name in ordered:
 187        _update_source_columns(source_name)
 188
 189    for i, join in enumerate(joins):
 190        source_table = ordered[-1]
 191        if source_table:
 192            _update_source_columns(source_table)
 193
 194        join_table = join.alias_or_name
 195        ordered.append(join_table)
 196
 197        using = join.args.get("using")
 198        if not using:
 199            continue
 200
 201        join_columns = resolver.get_source_columns(join_table)
 202        conditions = []
 203        using_identifier_count = len(using)
 204        is_semi_or_anti_join = join.is_semi_or_anti_join
 205
 206        for identifier in using:
 207            identifier = identifier.name
 208            table = columns.get(identifier)
 209
 210            if not table or identifier not in join_columns:
 211                if (columns and "*" not in columns) and join_columns:
 212                    raise OptimizeError(f"Cannot automatically join: {identifier}")
 213
 214            table = table or source_table
 215
 216            if i == 0 or using_identifier_count == 1:
 217                lhs: exp.Expression = exp.column(identifier, table=table)
 218            else:
 219                coalesce_columns = [
 220                    exp.column(identifier, table=t)
 221                    for t in ordered[:-1]
 222                    if identifier in resolver.get_source_columns(t)
 223                ]
 224                if len(coalesce_columns) > 1:
 225                    lhs = exp.func("coalesce", *coalesce_columns)
 226                else:
 227                    lhs = exp.column(identifier, table=table)
 228
 229            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))
 230
 231            # Set all values in the dict to None, because we only care about the key ordering
 232            tables = column_tables.setdefault(identifier, {})
 233
 234            # Do not update the dict if this was a SEMI/ANTI join in
 235            # order to avoid generating COALESCE columns for this join pair
 236            if not is_semi_or_anti_join:
 237                if table not in tables:
 238                    tables[table] = None
 239                if join_table not in tables:
 240                    tables[join_table] = None
 241
 242        join.args.pop("using")
 243        join.set("on", exp.and_(*conditions, copy=False))
 244
 245    if column_tables:
 246        for column in scope.columns:
 247            if not column.table and column.name in column_tables:
 248                tables = column_tables[column.name]
 249                coalesce_args = [exp.column(column.name, table=table) for table in tables]
 250                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)
 251
 252                if isinstance(column.parent, exp.Select):
 253                    # Ensure the USING column keeps its name if it's projected
 254                    replacement = alias(replacement, alias=column.name, copy=False)
 255                elif isinstance(column.parent, exp.Struct):
 256                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
 257                    replacement = exp.PropertyEQ(
 258                        this=exp.to_identifier(column.name), expression=replacement
 259                    )
 260
 261                scope.replace(column, replacement)
 262
 263    return column_tables
 264
 265
 266def _expand_alias_refs(
 267    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
 268) -> None:
 269    """
 270    Expand references to aliases.
 271    Example:
 272        SELECT y.foo AS bar, bar * 2 AS baz FROM y
 273     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
 274    """
 275    expression = scope.expression
 276
 277    if not isinstance(expression, exp.Select) or dialect == "oracle":
 278        return
 279
 280    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
 281    projections = {s.alias_or_name for s in expression.selects}
 282
 283    def replace_columns(
 284        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
 285    ) -> None:
 286        is_group_by = isinstance(node, exp.Group)
 287        is_having = isinstance(node, exp.Having)
 288        if not node or (expand_only_groupby and not is_group_by):
 289            return
 290
 291        for column in walk_in_scope(node, prune=lambda node: node.is_star):
 292            if not isinstance(column, exp.Column):
 293                continue
 294
 295            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g.:
 296            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
 297            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, as it would result in FUNC(FUNC(col))
 298            # This is not required for the HAVING clause, as it can evaluate expressions using both the alias and the table columns
 299            if expand_only_groupby and is_group_by and column.parent is not node:
 300                continue
 301
 302            skip_replace = False
 303            table = resolver.get_table(column.name) if resolve_table and not column.table else None
 304            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
 305
 306            if alias_expr:
 307                skip_replace = bool(
 308                    alias_expr.find(exp.AggFunc)
 309                    and column.find_ancestor(exp.AggFunc)
 310                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
 311                )
 312
 313                # BigQuery's HAVING clause gets confused if an alias matches a source.
 314                # SELECT x.a, max(x.b) AS x FROM x GROUP BY 1 HAVING x > 1;
 315                # If HAVING x is expanded to max(x.b), BigQuery treats x as the new projection x instead of the table
 316                if is_having and dialect == "bigquery":
 317                    skip_replace = skip_replace or any(
 318                        node.parts[0].name in projections
 319                        for node in alias_expr.find_all(exp.Column)
 320                    )
 321
 322            if table and (not alias_expr or skip_replace):
 323                column.set("table", table)
 324            elif not column.table and alias_expr and not skip_replace:
 325                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
 326                    if literal_index:
 327                        column.replace(exp.Literal.number(i))
 328                else:
 329                    column = column.replace(exp.paren(alias_expr))
 330                    simplified = simplify_parens(column)
 331                    if simplified is not column:
 332                        column.replace(simplified)
 333
 334    for i, projection in enumerate(expression.selects):
 335        replace_columns(projection)
 336        if isinstance(projection, exp.Alias):
 337            alias_to_expression[projection.alias] = (projection.this, i + 1)
 338
 339    parent_scope = scope
 340    while parent_scope.is_union:
 341        parent_scope = parent_scope.parent
 342
 343    # We shouldn't expand aliases if they match the recursive CTE's columns
 344    if parent_scope.is_cte:
 345        cte = parent_scope.expression.parent
 346        if cte.find_ancestor(exp.With).recursive:
 347            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
 348                alias_to_expression.pop(recursive_cte_column.output_name, None)
 349
 350    replace_columns(expression.args.get("where"))
 351    replace_columns(expression.args.get("group"), literal_index=True)
 352    replace_columns(expression.args.get("having"), resolve_table=True)
 353    replace_columns(expression.args.get("qualify"), resolve_table=True)
 354
 355    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
 356    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
 357    if dialect == "snowflake":
 358        for join in expression.args.get("joins") or []:
 359            replace_columns(join)
 360
 361    scope.clear_cache()
 362
 363
 364def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
 365    expression = scope.expression
 366    group = expression.args.get("group")
 367    if not group:
 368        return
 369
 370    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
 371    expression.set("group", group)
 372
 373
 374def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
 375    for modifier_key in ("order", "distinct"):
 376        modifier = scope.expression.args.get(modifier_key)
 377        if isinstance(modifier, exp.Distinct):
 378            modifier = modifier.args.get("on")
 379
 380        if not isinstance(modifier, exp.Expression):
 381            continue
 382
 383        modifier_expressions = modifier.expressions
 384        if modifier_key == "order":
 385            modifier_expressions = [ordered.this for ordered in modifier_expressions]
 386
 387        for original, expanded in zip(
 388            modifier_expressions,
 389            _expand_positional_references(
 390                scope, modifier_expressions, resolver.schema.dialect, alias=True
 391            ),
 392        ):
 393            for agg in original.find_all(exp.AggFunc):
 394                for col in agg.find_all(exp.Column):
 395                    if not col.table:
 396                        col.set("table", resolver.get_table(col.name))
 397
 398            original.replace(expanded)
 399
 400        if scope.expression.args.get("group"):
 401            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
 402
 403            for expression in modifier_expressions:
 404                expression.replace(
 405                    exp.to_identifier(_select_by_pos(scope, expression).alias)
 406                    if expression.is_int
 407                    else selects.get(expression, expression)
 408                )
 409
 410
 411def _expand_positional_references(
 412    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
 413) -> t.List[exp.Expression]:
 414    new_nodes: t.List[exp.Expression] = []
 415    ambiguous_projections = None
 416
 417    for node in expressions:
 418        if node.is_int:
 419            select = _select_by_pos(scope, t.cast(exp.Literal, node))
 420
 421            if alias:
 422                new_nodes.append(exp.column(select.args["alias"].copy()))
 423            else:
 424                select = select.this
 425
 426                if dialect == "bigquery":
 427                    if ambiguous_projections is None:
 428                        # When a projection name is also a source name and it is referenced in the
 429                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
 430                        ambiguous_projections = {
 431                            s.alias_or_name
 432                            for s in scope.expression.selects
 433                            if s.alias_or_name in scope.selected_sources
 434                        }
 435
 436                    ambiguous = any(
 437                        column.parts[0].name in ambiguous_projections
 438                        for column in select.find_all(exp.Column)
 439                    )
 440                else:
 441                    ambiguous = False
 442
 443                if (
 444                    isinstance(select, exp.CONSTANTS)
 445                    or select.find(exp.Explode, exp.Unnest)
 446                    or ambiguous
 447                ):
 448                    new_nodes.append(node)
 449                else:
 450                    new_nodes.append(select.copy())
 451        else:
 452            new_nodes.append(node)
 453
 454    return new_nodes
 455
 456
 457def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
 458    try:
 459        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
 460    except IndexError:
 461        raise OptimizeError(f"Unknown output column: {node.name}")
 462
 463
 464def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
 465    """
 466    Converts `Column` instances that represent struct field lookup into chained `Dots`.
 467
 468    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
 469    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
 470    """
 471    converted = False
 472    for column in itertools.chain(scope.columns, scope.stars):
 473        if isinstance(column, exp.Dot):
 474            continue
 475
 476        column_table: t.Optional[str | exp.Identifier] = column.table
 477        if (
 478            column_table
 479            and column_table not in scope.sources
 480            and (
 481                not scope.parent
 482                or column_table not in scope.parent.sources
 483                or not scope.is_correlated_subquery
 484            )
 485        ):
 486            root, *parts = column.parts
 487
 488            if root.name in scope.sources:
 489                # The struct is already qualified, but we still need to change the AST
 490                column_table = root
 491                root, *parts = parts
 492            else:
 493                column_table = resolver.get_table(root.name)
 494
 495            if column_table:
 496                converted = True
 497                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
 498
 499    if converted:
 500        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
 501        # a `for column in scope.columns` iteration, even though they shouldn't be
 502        scope.clear_cache()
 503
 504
 505def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
 506    """Disambiguate columns, ensuring each column specifies a source"""
 507    for column in scope.columns:
 508        column_table = column.table
 509        column_name = column.name
 510
 511        if column_table and column_table in scope.sources:
 512            source_columns = resolver.get_source_columns(column_table)
 513            if (
 514                not allow_partial_qualification
 515                and source_columns
 516                and column_name not in source_columns
 517                and "*" not in source_columns
 518            ):
 519                raise OptimizeError(f"Unknown column: {column_name}")
 520
 521        if not column_table:
 522            if scope.pivots and not column.find_ancestor(exp.Pivot):
 523                # Columns that aren't under the Pivot expression refer to the pivot's output,
 524                # so we qualify them using the pivot's alias rather than the pivoted source
 525                column.set("table", exp.to_identifier(scope.pivots[0].alias))
 526                continue
 527
 528            # column_table can be an empty string, because a BigQuery UNNEST has no table alias
 529            column_table = resolver.get_table(column_name)
 530            if column_table:
 531                column.set("table", column_table)
 532
 533    for pivot in scope.pivots:
 534        for column in pivot.find_all(exp.Column):
 535            if not column.table and column.name in resolver.all_columns:
 536                column_table = resolver.get_table(column.name)
 537                if column_table:
 538                    column.set("table", column_table)
 539
 540
 541def _expand_struct_stars(
 542    expression: exp.Dot,
 543) -> t.List[exp.Alias]:
 544    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
 545
 546    dot_column = t.cast(exp.Column, expression.find(exp.Column))
 547    if not dot_column.is_type(exp.DataType.Type.STRUCT):
 548        return []
 549
 550    # All nested struct values are ColumnDefs, so normalize the first exp.Column into one
 551    dot_column = dot_column.copy()
 552    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)
 553
 554    # The first part is the table name and the last part is the star, so they can be dropped
 555    dot_parts = expression.parts[1:-1]
 556
 557    # If we're expanding a nested struct, e.g. t.c.f1.f2.*, find the last struct (f2 in this case)
 558    for part in dot_parts[1:]:
 559        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
 560            # Unable to expand star unless all fields are named
 561            if not isinstance(field.this, exp.Identifier):
 562                return []
 563
 564            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
 565                starting_struct = field
 566                break
 567        else:
 568            # There is no matching field in the struct
 569            return []
 570
 571    taken_names = set()
 572    new_selections = []
 573
 574    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
 575        name = field.name
 576
 577        # Ambiguous or anonymous fields can't be expanded
 578        if name in taken_names or not isinstance(field.this, exp.Identifier):
 579            return []
 580
 581        taken_names.add(name)
 582
 583        this = field.this.copy()
 584        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
 585        new_column = exp.column(
 586            t.cast(exp.Identifier, root),
 587            table=dot_column.args.get("table"),
 588            fields=t.cast(t.List[exp.Identifier], parts),
 589        )
 590        new_selections.append(alias(new_column, this, copy=False))
 591
 592    return new_selections
 593
 594
 595def _expand_stars(
 596    scope: Scope,
 597    resolver: Resolver,
 598    using_column_tables: t.Dict[str, t.Any],
 599    pseudocolumns: t.Set[str],
 600    annotator: TypeAnnotator,
 601) -> None:
 602    """Expand stars to lists of column selections"""
 603
 604    new_selections: t.List[exp.Expression] = []
 605    except_columns: t.Dict[int, t.Set[str]] = {}
 606    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
 607    rename_columns: t.Dict[int, t.Dict[str, str]] = {}
 608
 609    coalesced_columns = set()
 610    dialect = resolver.schema.dialect
 611
 612    pivot_output_columns = None
 613    pivot_exclude_columns: t.Set[str] = set()
 614
 615    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
 616    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
 617        if pivot.unpivot:
 618            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
 619
 620            for field in pivot.fields:
 621                if isinstance(field, exp.In):
 622                    pivot_exclude_columns.update(
 623                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
 624                    )
 625
 626        else:
 627            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
 628
 629            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
 630            if not pivot_output_columns:
 631                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
 632
 633    is_bigquery = dialect == "bigquery"
 634    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
 635        # Found struct expansion, annotate scope ahead of time
 636        annotator.annotate_scope(scope)
 637
 638    for expression in scope.expression.selects:
 639        tables = []
 640        if isinstance(expression, exp.Star):
 641            tables.extend(scope.selected_sources)
 642            _add_except_columns(expression, tables, except_columns)
 643            _add_replace_columns(expression, tables, replace_columns)
 644            _add_rename_columns(expression, tables, rename_columns)
 645        elif expression.is_star:
 646            if not isinstance(expression, exp.Dot):
 647                tables.append(expression.table)
 648                _add_except_columns(expression.this, tables, except_columns)
 649                _add_replace_columns(expression.this, tables, replace_columns)
 650                _add_rename_columns(expression.this, tables, rename_columns)
 651            elif is_bigquery:
 652                struct_fields = _expand_struct_stars(expression)
 653                if struct_fields:
 654                    new_selections.extend(struct_fields)
 655                    continue
 656
 657        if not tables:
 658            new_selections.append(expression)
 659            continue
 660
 661        for table in tables:
 662            if table not in scope.sources:
 663                raise OptimizeError(f"Unknown table: {table}")
 664
 665            columns = resolver.get_source_columns(table, only_visible=True)
 666            columns = columns or scope.outer_columns
 667
 668            if pseudocolumns:
 669                columns = [name for name in columns if name.upper() not in pseudocolumns]
 670
 671            if not columns or "*" in columns:
 672                return
 673
 674            table_id = id(table)
 675            columns_to_exclude = except_columns.get(table_id) or set()
 676            renamed_columns = rename_columns.get(table_id, {})
 677            replaced_columns = replace_columns.get(table_id, {})
 678
 679            if pivot:
 680                if pivot_output_columns and pivot_exclude_columns:
 681                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
 682                    pivot_columns.extend(pivot_output_columns)
 683                else:
 684                    pivot_columns = pivot.alias_column_names
 685
 686                if pivot_columns:
 687                    new_selections.extend(
 688                        alias(exp.column(name, table=pivot.alias), name, copy=False)
 689                        for name in pivot_columns
 690                        if name not in columns_to_exclude
 691                    )
 692                    continue
 693
 694            for name in columns:
 695                if name in columns_to_exclude or name in coalesced_columns:
 696                    continue
 697                if name in using_column_tables and table in using_column_tables[name]:
 698                    coalesced_columns.add(name)
 699                    tables = using_column_tables[name]
 700                    coalesce_args = [exp.column(name, table=table) for table in tables]
 701
 702                    new_selections.append(
 703                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
 704                    )
 705                else:
 706                    alias_ = renamed_columns.get(name, name)
 707                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
 708                    new_selections.append(
 709                        alias(selection_expr, alias_, copy=False)
 710                        if alias_ != name
 711                        else selection_expr
 712                    )
 713
 714    # Ensures we don't overwrite the initial selections with an empty list
 715    if new_selections and isinstance(scope.expression, exp.Select):
 716        scope.expression.set("expressions", new_selections)
 717
 718
 719def _add_except_columns(
 720    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
 721) -> None:
 722    except_ = expression.args.get("except")
 723
 724    if not except_:
 725        return
 726
 727    columns = {e.name for e in except_}
 728
 729    for table in tables:
 730        except_columns[id(table)] = columns
 731
 732
 733def _add_rename_columns(
 734    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
 735) -> None:
 736    rename = expression.args.get("rename")
 737
 738    if not rename:
 739        return
 740
 741    columns = {e.this.name: e.alias for e in rename}
 742
 743    for table in tables:
 744        rename_columns[id(table)] = columns
 745
 746
 747def _add_replace_columns(
 748    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
 749) -> None:
 750    replace = expression.args.get("replace")
 751
 752    if not replace:
 753        return
 754
 755    columns = {e.alias: e for e in replace}
 756
 757    for table in tables:
 758        replace_columns[id(table)] = columns
 759
 760
 761def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
 762    """Ensure all output columns are aliased"""
 763    if isinstance(scope_or_expression, exp.Expression):
 764        scope = build_scope(scope_or_expression)
 765        if not isinstance(scope, Scope):
 766            return
 767    else:
 768        scope = scope_or_expression
 769
 770    new_selections = []
 771    for i, (selection, aliased_column) in enumerate(
 772        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
 773    ):
 774        if selection is None or isinstance(selection, exp.QueryTransform):
 775            break
 776
 777        if isinstance(selection, exp.Subquery):
 778            if not selection.output_name:
 779                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
 780        elif not isinstance(selection, exp.Alias) and not selection.is_star:
 781            selection = alias(
 782                selection,
 783                alias=selection.output_name or f"_col_{i}",
 784                copy=False,
 785            )
 786        if aliased_column:
 787            selection.set("alias", exp.to_identifier(aliased_column))
 788
 789        new_selections.append(selection)
 790
 791    if new_selections and isinstance(scope.expression, exp.Select):
 792        scope.expression.set("expressions", new_selections)
 793
 794
 795def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
 796    """Makes sure all identifiers that need to be quoted are quoted."""
 797    return expression.transform(
 798        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
 799    )  # type: ignore
 800
 801
 802def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
 803    """
 804    Pushes down the CTE alias columns into the projection.
 805
 806    This step is useful in Snowflake, where the CTE alias columns can be referenced in the HAVING clause.
 807
 808    Example:
 809        >>> import sqlglot
 810        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
 811        >>> pushdown_cte_alias_columns(expression).sql()
 812        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
 813
 814    Args:
 815        expression: Expression to pushdown.
 816
 817    Returns:
 818        The expression with the CTE aliases pushed down into the projection.
 819    """
 820    for cte in expression.find_all(exp.CTE):
 821        if cte.alias_column_names:
 822            new_expressions = []
 823            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
 824                if isinstance(projection, exp.Alias):
 825                    projection.set("alias", _alias)
 826                else:
 827                    projection = alias(projection, alias=_alias)
 828                new_expressions.append(projection)
 829            cte.this.set("expressions", new_expressions)
 830
 831    return expression
 832
 833
 834class Resolver:
 835    """
 836    Helper for resolving columns.
 837
 838    This is a class so we can lazily load some things and easily share them across functions.
 839    """
 840
 841    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
 842        self.scope = scope
 843        self.schema = schema
 844        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
 845        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
 846        self._all_columns: t.Optional[t.Set[str]] = None
 847        self._infer_schema = infer_schema
 848        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
 849
 850    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
 851        """
 852        Get the table for a column name.
 853
 854        Args:
 855            column_name: The column name to find the table for.
 856        Returns:
 857            The table name if it can be found/inferred.
 858        """
 859        if self._unambiguous_columns is None:
 860            self._unambiguous_columns = self._get_unambiguous_columns(
 861                self._get_all_source_columns()
 862            )
 863
 864        table_name = self._unambiguous_columns.get(column_name)
 865
 866        if not table_name and self._infer_schema:
 867            sources_without_schema = tuple(
 868                source
 869                for source, columns in self._get_all_source_columns().items()
 870                if not columns or "*" in columns
 871            )
 872            if len(sources_without_schema) == 1:
 873                table_name = sources_without_schema[0]
 874
 875        if table_name not in self.scope.selected_sources:
 876            return exp.to_identifier(table_name)
 877
 878        node, _ = self.scope.selected_sources.get(table_name)
 879
 880        if isinstance(node, exp.Query):
 881            while node and node.alias != table_name:
 882                node = node.parent
 883
 884        node_alias = node.args.get("alias")
 885        if node_alias:
 886            return exp.to_identifier(node_alias.this)
 887
 888        return exp.to_identifier(table_name)
 889
 890    @property
 891    def all_columns(self) -> t.Set[str]:
 892        """All available columns of all sources in this scope"""
 893        if self._all_columns is None:
 894            self._all_columns = {
 895                column for columns in self._get_all_source_columns().values() for column in columns
 896            }
 897        return self._all_columns
 898
 899    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
 900        """Resolve the source columns for a given source `name`."""
 901        cache_key = (name, only_visible)
 902        if cache_key not in self._get_source_columns_cache:
 903            if name not in self.scope.sources:
 904                raise OptimizeError(f"Unknown table: {name}")
 905
 906            source = self.scope.sources[name]
 907
 908            if isinstance(source, exp.Table):
 909                columns = self.schema.column_names(source, only_visible)
 910            elif isinstance(source, Scope) and isinstance(
 911                source.expression, (exp.Values, exp.Unnest)
 912            ):
 913                columns = source.expression.named_selects
 914
 915                # in bigquery, unnest structs are automatically scoped as tables, so you can
 916                # In BigQuery, UNNEST structs are automatically scoped as tables, so you can
 917                # directly select a struct field in a query.
 918                # This handles the case where the UNNEST is statically defined.
 919                    if source.expression.is_type(exp.DataType.Type.STRUCT):
 920                        for k in source.expression.type.expressions:  # type: ignore
 921                            columns.append(k.name)
 922            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
 923                set_op = source.expression
 924
 925                # BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
 926                on_column_list = set_op.args.get("on")
 927
 928                if on_column_list:
 929                    # The resulting columns are the columns in the ON clause:
 930                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
 931                    columns = [col.name for col in on_column_list]
 932                elif set_op.side or set_op.kind:
 933                    side = set_op.side
 934                    kind = set_op.kind
 935
 936                    left = set_op.left.named_selects
 937                    right = set_op.right.named_selects
 938
 939                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
 940                    if side == "LEFT":
 941                        columns = left
 942                    elif side == "FULL":
 943                        columns = list(dict.fromkeys(left + right))
 944                    elif kind == "INNER":
 945                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
 946                else:
 947                    columns = set_op.named_selects
 948            else:
 949                select = seq_get(source.expression.selects, 0)
 950
 951                if isinstance(select, exp.QueryTransform):
 952                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
 953                    schema = select.args.get("schema")
 954                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
 955                else:
 956                    columns = source.expression.named_selects
 957
 958            node, _ = self.scope.selected_sources.get(name) or (None, None)
 959            if isinstance(node, Scope):
 960                column_aliases = node.expression.alias_column_names
 961            elif isinstance(node, exp.Expression):
 962                column_aliases = node.alias_column_names
 963            else:
 964                column_aliases = []
 965
 966            if column_aliases:
 967                # If the source's columns are aliased, their aliases shadow the corresponding column names.
 968                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
 969                columns = [
 970                    alias or name
 971                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
 972                ]
 973
 974            self._get_source_columns_cache[cache_key] = columns
 975
 976        return self._get_source_columns_cache[cache_key]
 977
 978    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
 979        if self._source_columns is None:
 980            self._source_columns = {
 981                source_name: self.get_source_columns(source_name)
 982                for source_name, source in itertools.chain(
 983                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
 984                )
 985            }
 986        return self._source_columns
 987
 988    def _get_unambiguous_columns(
 989        self, source_columns: t.Dict[str, t.Sequence[str]]
 990    ) -> t.Mapping[str, str]:
 991        """
 992        Find all the unambiguous columns in sources.
 993
 994        Args:
 995            source_columns: Mapping of names to source columns.
 996
 997        Returns:
 998            Mapping of column name to source name.
 999        """
1000        if not source_columns:
1001            return {}
1002
1003        source_columns_pairs = list(source_columns.items())
1004
1005        first_table, first_columns = source_columns_pairs[0]
1006
1007        if len(source_columns_pairs) == 1:
1008            # Performance optimization - avoid copying first_columns if there is only one table.
1009            return SingleValuedMapping(first_columns, first_table)
1010
1011        unambiguous_columns = {col: first_table for col in first_columns}
1012        all_columns = set(unambiguous_columns)
1013
1014        for table, columns in source_columns_pairs[1:]:
1015            unique = set(columns)
1016            ambiguous = all_columns.intersection(unique)
1017            all_columns.update(columns)
1018
1019            for column in ambiguous:
1020                unambiguous_columns.pop(column, None)
1021            for column in unique.difference(ambiguous):
1022                unambiguous_columns[column] = table
1023
1024        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
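
For instance, star expansion under a known schema (a minimal sketch; the two-column schema below is hypothetical):

>>> import sqlglot
>>> schema = {"tbl": {"a": "INT", "b": "INT"}}
>>> expression = sqlglot.parse_one("SELECT * FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.a AS a, tbl.b AS b FROM tbl'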
def validate_qualify_columns(expression: ~E) -> ~E:

Raise an OptimizeError if any columns aren't qualified
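
A sketch of the failure mode, assuming hypothetical tables x and y with no schema information, so c can't be resolved to either source:

>>> import sqlglot
>>> from sqlglot.errors import OptimizeError
>>> expression = sqlglot.parse_one("SELECT c FROM x JOIN y ON x.id = y.id")
>>> try:
...     validate_qualify_columns(qualify_columns(expression, schema={}))
... except OptimizeError as error:
...     print(type(error).__name__)
OptimizeError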

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:

Ensure all output columns are aliased
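
A minimal sketch, assuming a hypothetical table x; the unnamed projection gets a positional _col_{i} alias:

>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT x.a, 1 + 1 FROM x")
>>> qualify_outputs(expression)
>>> expression.sql()
'SELECT x.a AS a, 1 + 1 AS _col_1 FROM x'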

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:

Makes sure all identifiers that need to be quoted are quoted.
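
A minimal sketch with a hypothetical table; with the default identify=True, every identifier is quoted:

>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM tbl")
>>> quote_identifiers(expression, dialect="duckdb").sql(dialect="duckdb")
'SELECT "a" FROM "tbl"'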

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake, where the CTE alias columns can be referenced in the HAVING clause.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
 835class Resolver:
 836    """
 837    Helper for resolving columns.
 838
 839    This is a class so we can lazily load some things and easily share them across functions.
 840    """
 841
 842    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
 843        self.scope = scope
 844        self.schema = schema
 845        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
 846        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
 847        self._all_columns: t.Optional[t.Set[str]] = None
 848        self._infer_schema = infer_schema
 849        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
 850
 851    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
 852        """
 853        Get the table for a column name.
 854
 855        Args:
 856            column_name: The column name to find the table for.
 857        Returns:
 858            The table name if it can be found/inferred.
 859        """
 860        if self._unambiguous_columns is None:
 861            self._unambiguous_columns = self._get_unambiguous_columns(
 862                self._get_all_source_columns()
 863            )
 864
 865        table_name = self._unambiguous_columns.get(column_name)
 866
 867        if not table_name and self._infer_schema:
 868            sources_without_schema = tuple(
 869                source
 870                for source, columns in self._get_all_source_columns().items()
 871                if not columns or "*" in columns
 872            )
 873            if len(sources_without_schema) == 1:
 874                table_name = sources_without_schema[0]
 875
 876        if table_name not in self.scope.selected_sources:
 877            return exp.to_identifier(table_name)
 878
 879        node, _ = self.scope.selected_sources.get(table_name)
 880
 881        if isinstance(node, exp.Query):
 882            while node and node.alias != table_name:
 883                node = node.parent
 884
 885        node_alias = node.args.get("alias")
 886        if node_alias:
 887            return exp.to_identifier(node_alias.this)
 888
 889        return exp.to_identifier(table_name)
 890
 891    @property
 892    def all_columns(self) -> t.Set[str]:
 893        """All available columns of all sources in this scope"""
 894        if self._all_columns is None:
 895            self._all_columns = {
 896                column for columns in self._get_all_source_columns().values() for column in columns
 897            }
 898        return self._all_columns
 899
 900    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
 901        """Resolve the source columns for a given source `name`."""
 902        cache_key = (name, only_visible)
 903        if cache_key not in self._get_source_columns_cache:
 904            if name not in self.scope.sources:
 905                raise OptimizeError(f"Unknown table: {name}")
 906
 907            source = self.scope.sources[name]
 908
 909            if isinstance(source, exp.Table):
 910                columns = self.schema.column_names(source, only_visible)
 911            elif isinstance(source, Scope) and isinstance(
 912                source.expression, (exp.Values, exp.Unnest)
 913            ):
 914                columns = source.expression.named_selects
 915
 916                # in bigquery, unnest structs are automatically scoped as tables, so you can
 917                # directly select a struct field in a query.
 918                # this handles the case where the unnest is statically defined.
 919                if self.schema.dialect == "bigquery":
 920                    if source.expression.is_type(exp.DataType.Type.STRUCT):
 921                        for k in source.expression.type.expressions:  # type: ignore
 922                            columns.append(k.name)
 923            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
 924                set_op = source.expression
 925
 926                # BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
 927                on_column_list = set_op.args.get("on")
 928
 929                if on_column_list:
 930                    # The resulting columns are the columns in the ON clause:
 931                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
 932                    columns = [col.name for col in on_column_list]
 933                elif set_op.side or set_op.kind:
 934                    side = set_op.side
 935                    kind = set_op.kind
 936
 937                    left = set_op.left.named_selects
 938                    right = set_op.right.named_selects
 939
 940                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
 941                    if side == "LEFT":
 942                        columns = left
 943                    elif side == "FULL":
 944                        columns = list(dict.fromkeys(left + right))
 945                    elif kind == "INNER":
 946                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
 947                else:
 948                    columns = set_op.named_selects
 949            else:
 950                select = seq_get(source.expression.selects, 0)
 951
 952                if isinstance(select, exp.QueryTransform):
 953                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
 954                    schema = select.args.get("schema")
 955                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
 956                else:
 957                    columns = source.expression.named_selects
 958
 959            node, _ = self.scope.selected_sources.get(name) or (None, None)
 960            if isinstance(node, Scope):
 961                column_aliases = node.expression.alias_column_names
 962            elif isinstance(node, exp.Expression):
 963                column_aliases = node.alias_column_names
 964            else:
 965                column_aliases = []
 966
 967            if column_aliases:
 968                # If the source's columns are aliased, their aliases shadow the corresponding column names.
 969                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
 970                columns = [
 971                    alias or name
 972                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
 973                ]
 974
 975            self._get_source_columns_cache[cache_key] = columns
 976
 977        return self._get_source_columns_cache[cache_key]
 978
 979    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
 980        if self._source_columns is None:
 981            self._source_columns = {
 982                source_name: self.get_source_columns(source_name)
 983                for source_name, source in itertools.chain(
 984                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
 985                )
 986            }
 987        return self._source_columns
 988
 989    def _get_unambiguous_columns(
 990        self, source_columns: t.Dict[str, t.Sequence[str]]
 991    ) -> t.Mapping[str, str]:
 992        """
 993        Find all the unambiguous columns in sources.
 994
 995        Args:
 996            source_columns: Mapping of names to source columns.
 997
 998        Returns:
 999            Mapping of column name to source name.
1000        """
1001        if not source_columns:
1002            return {}
1003
1004        source_columns_pairs = list(source_columns.items())
1005
1006        first_table, first_columns = source_columns_pairs[0]
1007
1008        if len(source_columns_pairs) == 1:
1009            # Performance optimization - avoid copying first_columns if there is only one table.
1010            return SingleValuedMapping(first_columns, first_table)
1011
1012        unambiguous_columns = {col: first_table for col in first_columns}
1013        all_columns = set(unambiguous_columns)
1014
1015        for table, columns in source_columns_pairs[1:]:
1016            unique = set(columns)
1017            ambiguous = all_columns.intersection(unique)
1018            all_columns.update(columns)
1019
1020            for column in ambiguous:
1021                unambiguous_columns.pop(column, None)
1022            for column in unique.difference(ambiguous):
1023                unambiguous_columns[column] = table
1024
1025        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.
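
A minimal sketch of building a Resolver by hand, assuming the made-up tables t1 and t2 below; qualify_columns normally constructs one per scope. A column that appears in more than one source cannot be attributed to a single table:

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.schema import ensure_schema
>>> schema = ensure_schema({"t1": {"a": "INT", "shared": "INT"}, "t2": {"b": "INT", "shared": "INT"}})
>>> scope = build_scope(sqlglot.parse_one("SELECT a, b, shared FROM t1 CROSS JOIN t2"))
>>> resolver = Resolver(scope, schema)
>>> resolver.get_table("a").name  # unambiguous: only t1 has it
't1'
>>> print(resolver.get_table("shared"))  # ambiguous between t1 and t2
None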

Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
842    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
843        self.scope = scope
844        self.schema = schema
845        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
846        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
847        self._all_columns: t.Optional[t.Set[str]] = None
848        self._infer_schema = infer_schema
849        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
851    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
852        """
853        Get the table for a column name.
854
855        Args:
856            column_name: The column name to find the table for.
857        Returns:
858            The table name if it can be found/inferred.
859        """
860        if self._unambiguous_columns is None:
861            self._unambiguous_columns = self._get_unambiguous_columns(
862                self._get_all_source_columns()
863            )
864
865        table_name = self._unambiguous_columns.get(column_name)
866
867        if not table_name and self._infer_schema:
868            sources_without_schema = tuple(
869                source
870                for source, columns in self._get_all_source_columns().items()
871                if not columns or "*" in columns
872            )
873            if len(sources_without_schema) == 1:
874                table_name = sources_without_schema[0]
875
876        if table_name not in self.scope.selected_sources:
877            return exp.to_identifier(table_name)
878
879        node, _ = self.scope.selected_sources.get(table_name)
880
881        if isinstance(node, exp.Query):
882            while node and node.alias != table_name:
883                node = node.parent
884
885        node_alias = node.args.get("alias")
886        if node_alias:
887            return exp.to_identifier(node_alias.this)
888
889        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.
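
When the schema doesn't cover a source, get_table can still infer the table, provided exactly one source has unknown columns. A sketch with a made-up table other that is absent from the schema:

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.schema import ensure_schema
>>> schema = ensure_schema({"known": {"a": "INT"}})
>>> scope = build_scope(sqlglot.parse_one("SELECT a, mystery FROM known CROSS JOIN other"))
>>> Resolver(scope, schema).get_table("mystery").name
'other'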

all_columns: Set[str]
891    @property
892    def all_columns(self) -> t.Set[str]:
893        """All available columns of all sources in this scope"""
894        if self._all_columns is None:
895            self._all_columns = {
896                column for columns in self._get_all_source_columns().values() for column in columns
897            }
898        return self._all_columns

All available columns of all sources in this scope
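
For example, with two made-up single-column tables in scope:

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.schema import ensure_schema
>>> schema = ensure_schema({"t1": {"a": "INT"}, "t2": {"b": "INT"}})
>>> scope = build_scope(sqlglot.parse_one("SELECT * FROM t1 CROSS JOIN t2"))
>>> sorted(Resolver(scope, schema).all_columns)
['a', 'b']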

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
900    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
901        """Resolve the source columns for a given source `name`."""
902        cache_key = (name, only_visible)
903        if cache_key not in self._get_source_columns_cache:
904            if name not in self.scope.sources:
905                raise OptimizeError(f"Unknown table: {name}")
906
907            source = self.scope.sources[name]
908
909            if isinstance(source, exp.Table):
910                columns = self.schema.column_names(source, only_visible)
911            elif isinstance(source, Scope) and isinstance(
912                source.expression, (exp.Values, exp.Unnest)
913            ):
914                columns = source.expression.named_selects
915
916                # in bigquery, unnest structs are automatically scoped as tables, so you can
917                # directly select a struct field in a query.
918                # this handles the case where the unnest is statically defined.
919                if self.schema.dialect == "bigquery":
920                    if source.expression.is_type(exp.DataType.Type.STRUCT):
921                        for k in source.expression.type.expressions:  # type: ignore
922                            columns.append(k.name)
923            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
924                set_op = source.expression
925
926                # BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
927                on_column_list = set_op.args.get("on")
928
929                if on_column_list:
930                    # The resulting columns are the columns in the ON clause:
931                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
932                    columns = [col.name for col in on_column_list]
933                elif set_op.side or set_op.kind:
934                    side = set_op.side
935                    kind = set_op.kind
936
937                    left = set_op.left.named_selects
938                    right = set_op.right.named_selects
939
940                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
941                    if side == "LEFT":
942                        columns = left
943                    elif side == "FULL":
944                        columns = list(dict.fromkeys(left + right))
945                    elif kind == "INNER":
946                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
947                else:
948                    columns = set_op.named_selects
949            else:
950                select = seq_get(source.expression.selects, 0)
951
952                if isinstance(select, exp.QueryTransform):
953                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
954                    schema = select.args.get("schema")
955                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
956                else:
957                    columns = source.expression.named_selects
958
959            node, _ = self.scope.selected_sources.get(name) or (None, None)
960            if isinstance(node, Scope):
961                column_aliases = node.expression.alias_column_names
962            elif isinstance(node, exp.Expression):
963                column_aliases = node.alias_column_names
964            else:
965                column_aliases = []
966
967            if column_aliases:
968                # If the source's columns are aliased, their aliases shadow the corresponding column names.
969                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
970                columns = [
971                    alias or name
972                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
973                ]
974
975            self._get_source_columns_cache[cache_key] = columns
976
977        return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.
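
A sketch of the alias-shadowing step at the end of this method, using a made-up derived table whose alias column list renames its projections:

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.schema import ensure_schema
>>> sql = "SELECT x, y FROM (SELECT 1 AS a, 2 AS b) AS sub(x, y)"
>>> scope = build_scope(sqlglot.parse_one(sql))
>>> Resolver(scope, ensure_schema({})).get_source_columns("sub")
['x', 'y']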