# sqlglot.optimizer.qualify_columns

   1from __future__ import annotations
   2
   3import itertools
   4import typing as t
   5
   6from sqlglot import alias, exp
   7from sqlglot.dialects.dialect import Dialect, DialectType
   8from sqlglot.errors import OptimizeError
   9from sqlglot.helper import seq_get, SingleValuedMapping
  10from sqlglot.optimizer.annotate_types import TypeAnnotator
  11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
  12from sqlglot.optimizer.simplify import simplify_parens
  13from sqlglot.schema import Schema, ensure_schema
  14
  15if t.TYPE_CHECKING:
  16    from sqlglot._typing import E
  17
  18
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # With no schema information, fall back to inferring columns from the query itself
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        # Drop `(col1, col2, ...)` column-alias lists on CTEs / derived tables so they
        # don't interfere with qualification (see _pop_table_column_aliases)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects require alias expansion before qualification (see the flag);
        # without a schema it must also happen early, since aliases can shadow columns
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        # Second (post-qualification) alias expansion pass, used when a schema exists
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            # BigQuery-specific handling elsewhere relies on type annotations
            annotator.annotate_scope(scope)

    return expression
 113
 114
 115def validate_qualify_columns(expression: E) -> E:
 116    """Raise an `OptimizeError` if any columns aren't qualified"""
 117    all_unqualified_columns = []
 118    for scope in traverse_scope(expression):
 119        if isinstance(scope.expression, exp.Select):
 120            unqualified_columns = scope.unqualified_columns
 121
 122            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 123                column = scope.external_columns[0]
 124                for_table = f" for table: '{column.table}'" if column.table else ""
 125                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 126
 127            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 128                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 129                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 130                # this list here to ensure those in the former category will be excluded.
 131                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 132                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 133
 134            all_unqualified_columns.extend(unqualified_columns)
 135
 136    if all_unqualified_columns:
 137        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
 138
 139    return expression
 140
 141
 142def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
 143    name_columns = [
 144        field.this
 145        for field in unpivot.fields
 146        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
 147    ]
 148    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
 149
 150    return itertools.chain(name_columns, value_columns)
 151
 152
 153def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 154    """
 155    Remove table column aliases.
 156
 157    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 158    """
 159    for derived_table in derived_tables:
 160        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
 161            continue
 162        table_alias = derived_table.args.get("alias")
 163        if table_alias:
 164            table_alias.args.pop("columns", None)
 165
 166
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite JOIN ... USING (...) clauses into equivalent JOIN ... ON conditions.

    Returns:
        A mapping of each USING column name to an ordered "set" of source names
        (a dict whose values are all None — only key order matters). It is used
        below, and by `_expand_stars`, to replace unqualified references to USING
        columns with COALESCE expressions across the contributing sources.
    """
    # Maps each column name seen so far to the first source that provides it
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Non-join sources (e.g. the FROM clause), in selection order
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        # Most recently appended source; its columns can resolve the USING names
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only raise when we actually know the columns on both sides;
                # otherwise the schema may simply be incomplete
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # For chained USING joins, the LHS may need to COALESCE the column
                # across every previously joined source that provides it
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        # Replace the USING clause with the equivalent ON condition
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Rewrite unqualified references to USING columns as COALESCE over all sources
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
 263
 264
def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    # Nothing to expand for non-SELECTs; Oracle is intentionally excluded
    if not isinstance(expression, exp.Select) or dialect == "oracle":
        return

    # Maps a projection alias to (its expression, its 1-based SELECT position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        # Star expressions are pruned: their children are never alias references
        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                # Don't expand an aggregate-bearing alias inside another aggregate
                # (unless the reference is windowed), to avoid nesting aggregates
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If HAVING x is expanded to max(x.b), bigquery treats x as the new projection x instead of the table
                if is_having and dialect == "bigquery":
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        # Literal projections referenced in GROUP BY become ordinals
                        column.replace(exp.Literal.number(i))
                else:
                    # Wrap in parens to preserve precedence, then drop them if redundant
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    # Projections can reference aliases defined earlier in the same SELECT list,
    # so expand as we go before registering each new alias
    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()
 361
 362
 363def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
 364    expression = scope.expression
 365    group = expression.args.get("group")
 366    if not group:
 367        return
 368
 369    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
 370    expression.set("group", group)
 371
 372
def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references (e.g. ORDER BY 1) in the ORDER BY and DISTINCT ON
    clauses, qualifying any columns nested inside their aggregate functions.
    """
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            # Only the DISTINCT ON (...) part carries expressions to expand
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            # ORDER BY wraps each expression in an Ordered node — unwrap it
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            # Qualify columns under aggregates before replacing the node in-place
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            # With a GROUP BY present, replace expressions that match a projection
            # with that projection's output name (or its alias for ordinals)
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )
 408
 409
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """
    Replace integer literals that reference a projection by (1-based) position with
    either the projection's alias (when `alias=True`) or a copy of the projection
    expression itself. Non-integer nodes pass through untouched.
    """
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None  # computed lazily; BigQuery only

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the original ordinal when expansion isn't safe: constants,
                # row-producing EXPLODE/UNNEST projections, or ambiguous BQ names
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
 454
 455
 456def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
 457    try:
 458        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
 459    except IndexError:
 460        raise OptimizeError(f"Unknown output column: {node.name}")
 461
 462
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # Only rewrite when the "table" part doesn't actually name a known source
        # (in this scope, or — for correlated subqueries — in the parent scope),
        # meaning it must be a struct column whose fields are being accessed
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
 502
 503
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            # Already qualified: validate that the column exists on its source
            # (unless the source's columns are unknown or contain a "*")
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    # Qualify any remaining unqualified columns inside the PIVOT expressions themselves
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
 538
 539
def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column

    Returns the expanded list of aliased field selections, or an empty list when
    expansion isn't possible (non-struct column, anonymous/ambiguous fields, or a
    path that doesn't match the struct's fields).
    """

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        # Build `table.col.f1...fN AS fN` for each leaf field of the struct
        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
 592
 593
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    # All three dicts are keyed by id() of the table name entry in `tables`
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    # Note: only a single PIVOT/UNPIVOT is handled (see qualify_columns docstring)
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            # Columns consumed by the UNPIVOT's IN clause disappear from the output
            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )

        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare `SELECT *`: expand across every selected source
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified star, e.g. `tbl.*`
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                # `tbl.struct_col.*`: BigQuery struct star expansion
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            # Not a (expandable) star — keep the projection as-is
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Can't expand without complete column information; bail out and
                # leave every selection untouched
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # A USING column is projected once, coalesced across its sources
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    # Apply RENAME / REPLACE modifiers, aliasing only when needed
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
 716
 717
 718def _add_except_columns(
 719    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
 720) -> None:
 721    except_ = expression.args.get("except")
 722
 723    if not except_:
 724        return
 725
 726    columns = {e.name for e in except_}
 727
 728    for table in tables:
 729        except_columns[id(table)] = columns
 730
 731
 732def _add_rename_columns(
 733    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
 734) -> None:
 735    rename = expression.args.get("rename")
 736
 737    if not rename:
 738        return
 739
 740    columns = {e.this.name: e.alias for e in rename}
 741
 742    for table in tables:
 743        rename_columns[id(table)] = columns
 744
 745
 746def _add_replace_columns(
 747    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
 748) -> None:
 749    replace = expression.args.get("replace")
 750
 751    if not replace:
 752        return
 753
 754    columns = {e.alias: e for e in replace}
 755
 756    for table in tables:
 757        replace_columns[id(table)] = columns
 758
 759
 760def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
 761    """Ensure all output columns are aliased"""
 762    if isinstance(scope_or_expression, exp.Expression):
 763        scope = build_scope(scope_or_expression)
 764        if not isinstance(scope, Scope):
 765            return
 766    else:
 767        scope = scope_or_expression
 768
 769    new_selections = []
 770    for i, (selection, aliased_column) in enumerate(
 771        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
 772    ):
 773        if selection is None:
 774            break
 775
 776        if isinstance(selection, exp.Subquery):
 777            if not selection.output_name:
 778                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
 779        elif not isinstance(selection, exp.Alias) and not selection.is_star:
 780            selection = alias(
 781                selection,
 782                alias=selection.output_name or f"_col_{i}",
 783                copy=False,
 784            )
 785        if aliased_column:
 786            selection.set("alias", exp.to_identifier(aliased_column))
 787
 788        new_selections.append(selection)
 789
 790    if isinstance(scope.expression, exp.Select):
 791        scope.expression.set("expressions", new_selections)
 792
 793
 794def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
 795    """Makes sure all identifiers that need to be quoted are quoted."""
 796    return expression.transform(
 797        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
 798    )  # type: ignore
 799
 800
 801def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
 802    """
 803    Pushes down the CTE alias columns into the projection,
 804
 805    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
 806
 807    Example:
 808        >>> import sqlglot
 809        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
 810        >>> pushdown_cte_alias_columns(expression).sql()
 811        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
 812
 813    Args:
 814        expression: Expression to pushdown.
 815
 816    Returns:
 817        The expression with the CTE aliases pushed down into the projection.
 818    """
 819    for cte in expression.find_all(exp.CTE):
 820        if cte.alias_column_names:
 821            new_expressions = []
 822            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
 823                if isinstance(projection, exp.Alias):
 824                    projection.set("alias", _alias)
 825                else:
 826                    projection = alias(projection, alias=_alias)
 827                new_expressions.append(projection)
 828            cte.this.set("expressions", new_expressions)
 829
 830    return expression
 831
 832
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches; computed on first use by the methods below.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Cache for get_source_columns, keyed by (source name, only_visible).
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # The column isn't unambiguously attributable. If exactly one source
            # has no schema info (no columns, or a "*" placeholder), assume the
            # column belongs to that source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            # Note: table_name may be None here, in which case this returns None.
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the ancestor that actually carries this alias
            # (e.g. the Subquery node wrapping the query).
            while node and node.alias != table_name:
                node = node.parent

        # NOTE(review): if the walk above exhausts every parent, `node` is None and
        # `.args` would raise — presumably unreachable for selected sources; confirm.
        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Physical table: delegate to the schema.
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                set_op = source.expression

                # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
                on_column_list = set_op.args.get("on")

                if on_column_list:
                    # The resulting columns are the columns in the ON clause:
                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
                    columns = [col.name for col in on_column_list]
                elif set_op.side or set_op.kind:
                    side = set_op.side
                    kind = set_op.kind

                    left = set_op.left.named_selects
                    right = set_op.right.named_selects

                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
                    if side == "LEFT":
                        columns = left
                    elif side == "FULL":
                        columns = list(dict.fromkeys(left + right))
                    elif kind == "INNER":
                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
                else:
                    columns = set_op.named_selects
            else:
                # Any other derived source (subquery, CTE, ...): use its projections.
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Builds (once) the mapping of every selected/lateral source to its columns.
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column seen in more than one source can no longer be resolved
            # without qualification; later sources claim any still-unique names.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
 20def qualify_columns(
 21    expression: exp.Expression,
 22    schema: t.Dict | Schema,
 23    expand_alias_refs: bool = True,
 24    expand_stars: bool = True,
 25    infer_schema: t.Optional[bool] = None,
 26    allow_partial_qualification: bool = False,
 27) -> exp.Expression:
 28    """
 29    Rewrite sqlglot AST to have fully qualified columns.
 30
 31    Example:
 32        >>> import sqlglot
 33        >>> schema = {"tbl": {"col": "INT"}}
 34        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 35        >>> qualify_columns(expression, schema).sql()
 36        'SELECT tbl.col AS col FROM tbl'
 37
 38    Args:
 39        expression: Expression to qualify.
 40        schema: Database schema.
 41        expand_alias_refs: Whether to expand references to aliases.
 42        expand_stars: Whether to expand star queries. This is a necessary step
 43            for most of the optimizer's rules to work; do not set to False unless you
 44            know what you're doing!
 45        infer_schema: Whether to infer the schema if missing.
 46        allow_partial_qualification: Whether to allow partial qualification.
 47
 48    Returns:
 49        The qualified expression.
 50
 51    Notes:
 52        - Currently only handles a single PIVOT or UNPIVOT operator
 53    """
 54    schema = ensure_schema(schema)
 55    annotator = TypeAnnotator(schema)
 56    infer_schema = schema.empty if infer_schema is None else infer_schema
 57    dialect = Dialect.get_or_raise(schema.dialect)
 58    pseudocolumns = dialect.PSEUDOCOLUMNS
 59    bigquery = dialect == "bigquery"
 60
 61    for scope in traverse_scope(expression):
 62        scope_expression = scope.expression
 63        is_select = isinstance(scope_expression, exp.Select)
 64
 65        if is_select and scope_expression.args.get("connect"):
 66            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
 67            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
 68            scope_expression.transform(
 69                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
 70                copy=False,
 71            )
 72            scope.clear_cache()
 73
 74        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 75        _pop_table_column_aliases(scope.ctes)
 76        _pop_table_column_aliases(scope.derived_tables)
 77        using_column_tables = _expand_using(scope, resolver)
 78
 79        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
 80            _expand_alias_refs(
 81                scope,
 82                resolver,
 83                dialect,
 84                expand_only_groupby=bigquery,
 85            )
 86
 87        _convert_columns_to_dots(scope, resolver)
 88        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
 89
 90        if not schema.empty and expand_alias_refs:
 91            _expand_alias_refs(scope, resolver, dialect)
 92
 93        if is_select:
 94            if expand_stars:
 95                _expand_stars(
 96                    scope,
 97                    resolver,
 98                    using_column_tables,
 99                    pseudocolumns,
100                    annotator,
101                )
102            qualify_outputs(scope)
103
104        _expand_group_by(scope, dialect)
105
106        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
107        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
108        _expand_order_by_and_distinct_on(scope, resolver)
109
110        if bigquery:
111            annotator.annotate_scope(scope)
112
113    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
116def validate_qualify_columns(expression: E) -> E:
117    """Raise an `OptimizeError` if any columns aren't qualified"""
118    all_unqualified_columns = []
119    for scope in traverse_scope(expression):
120        if isinstance(scope.expression, exp.Select):
121            unqualified_columns = scope.unqualified_columns
122
123            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
124                column = scope.external_columns[0]
125                for_table = f" for table: '{column.table}'" if column.table else ""
126                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
127
128            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
129                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
130                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
131                # this list here to ensure those in the former category will be excluded.
132                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
133                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
134
135            all_unqualified_columns.extend(unqualified_columns)
136
137    if all_unqualified_columns:
138        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
139
140    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
761def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
762    """Ensure all output columns are aliased"""
763    if isinstance(scope_or_expression, exp.Expression):
764        scope = build_scope(scope_or_expression)
765        if not isinstance(scope, Scope):
766            return
767    else:
768        scope = scope_or_expression
769
770    new_selections = []
771    for i, (selection, aliased_column) in enumerate(
772        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
773    ):
774        if selection is None:
775            break
776
777        if isinstance(selection, exp.Subquery):
778            if not selection.output_name:
779                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
780        elif not isinstance(selection, exp.Alias) and not selection.is_star:
781            selection = alias(
782                selection,
783                alias=selection.output_name or f"_col_{i}",
784                copy=False,
785            )
786        if aliased_column:
787            selection.set("alias", exp.to_identifier(aliased_column))
788
789        new_selections.append(selection)
790
791    if isinstance(scope.expression, exp.Select):
792        scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
795def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
796    """Makes sure all identifiers that need to be quoted are quoted."""
797    return expression.transform(
798        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
799    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
802def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
803    """
804    Pushes down the CTE alias columns into the projection,
805
806    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
807
808    Example:
809        >>> import sqlglot
810        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
811        >>> pushdown_cte_alias_columns(expression).sql()
812        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
813
814    Args:
815        expression: Expression to pushdown.
816
817    Returns:
818        The expression with the CTE aliases pushed down into the projection.
819    """
820    for cte in expression.find_all(exp.CTE):
821        if cte.alias_column_names:
822            new_expressions = []
823            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
824                if isinstance(projection, exp.Alias):
825                    projection.set("alias", _alias)
826                else:
827                    projection = alias(projection, alias=_alias)
828                new_expressions.append(projection)
829            cte.this.set("expressions", new_expressions)
830
831    return expression

Pushes down the CTE alias columns into the projection,

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
 834class Resolver:
 835    """
 836    Helper for resolving columns.
 837
 838    This is a class so we can lazily load some things and easily share them across functions.
 839    """
 840
 841    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
 842        self.scope = scope
 843        self.schema = schema
 844        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
 845        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
 846        self._all_columns: t.Optional[t.Set[str]] = None
 847        self._infer_schema = infer_schema
 848        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
 849
 850    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
 851        """
 852        Get the table for a column name.
 853
 854        Args:
 855            column_name: The column name to find the table for.
 856        Returns:
 857            The table name if it can be found/inferred.
 858        """
 859        if self._unambiguous_columns is None:
 860            self._unambiguous_columns = self._get_unambiguous_columns(
 861                self._get_all_source_columns()
 862            )
 863
 864        table_name = self._unambiguous_columns.get(column_name)
 865
 866        if not table_name and self._infer_schema:
 867            sources_without_schema = tuple(
 868                source
 869                for source, columns in self._get_all_source_columns().items()
 870                if not columns or "*" in columns
 871            )
 872            if len(sources_without_schema) == 1:
 873                table_name = sources_without_schema[0]
 874
 875        if table_name not in self.scope.selected_sources:
 876            return exp.to_identifier(table_name)
 877
 878        node, _ = self.scope.selected_sources.get(table_name)
 879
 880        if isinstance(node, exp.Query):
 881            while node and node.alias != table_name:
 882                node = node.parent
 883
 884        node_alias = node.args.get("alias")
 885        if node_alias:
 886            return exp.to_identifier(node_alias.this)
 887
 888        return exp.to_identifier(table_name)
 889
 890    @property
 891    def all_columns(self) -> t.Set[str]:
 892        """All available columns of all sources in this scope"""
 893        if self._all_columns is None:
 894            self._all_columns = {
 895                column for columns in self._get_all_source_columns().values() for column in columns
 896            }
 897        return self._all_columns
 898
 899    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
 900        """Resolve the source columns for a given source `name`."""
 901        cache_key = (name, only_visible)
 902        if cache_key not in self._get_source_columns_cache:
 903            if name not in self.scope.sources:
 904                raise OptimizeError(f"Unknown table: {name}")
 905
 906            source = self.scope.sources[name]
 907
 908            if isinstance(source, exp.Table):
 909                columns = self.schema.column_names(source, only_visible)
 910            elif isinstance(source, Scope) and isinstance(
 911                source.expression, (exp.Values, exp.Unnest)
 912            ):
 913                columns = source.expression.named_selects
 914
 915                # in bigquery, unnest structs are automatically scoped as tables, so you can
 916                # directly select a struct field in a query.
 917                # this handles the case where the unnest is statically defined.
 918                if self.schema.dialect == "bigquery":
 919                    if source.expression.is_type(exp.DataType.Type.STRUCT):
 920                        for k in source.expression.type.expressions:  # type: ignore
 921                            columns.append(k.name)
 922            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
 923                set_op = source.expression
 924
 925                # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
 926                on_column_list = set_op.args.get("on")
 927
 928                if on_column_list:
 929                    # The resulting columns are the columns in the ON clause:
 930                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
 931                    columns = [col.name for col in on_column_list]
 932                elif set_op.side or set_op.kind:
 933                    side = set_op.side
 934                    kind = set_op.kind
 935
 936                    left = set_op.left.named_selects
 937                    right = set_op.right.named_selects
 938
 939                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
 940                    if side == "LEFT":
 941                        columns = left
 942                    elif side == "FULL":
 943                        columns = list(dict.fromkeys(left + right))
 944                    elif kind == "INNER":
 945                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
 946                else:
 947                    columns = set_op.named_selects
 948            else:
 949                columns = source.expression.named_selects
 950
 951            node, _ = self.scope.selected_sources.get(name) or (None, None)
 952            if isinstance(node, Scope):
 953                column_aliases = node.expression.alias_column_names
 954            elif isinstance(node, exp.Expression):
 955                column_aliases = node.alias_column_names
 956            else:
 957                column_aliases = []
 958
 959            if column_aliases:
 960                # If the source's columns are aliased, their aliases shadow the corresponding column names.
 961                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
 962                columns = [
 963                    alias or name
 964                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
 965                ]
 966
 967            self._get_source_columns_cache[cache_key] = columns
 968
 969        return self._get_source_columns_cache[cache_key]
 970
 971    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
 972        if self._source_columns is None:
 973            self._source_columns = {
 974                source_name: self.get_source_columns(source_name)
 975                for source_name, source in itertools.chain(
 976                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
 977                )
 978            }
 979        return self._source_columns
 980
 981    def _get_unambiguous_columns(
 982        self, source_columns: t.Dict[str, t.Sequence[str]]
 983    ) -> t.Mapping[str, str]:
 984        """
 985        Find all the unambiguous columns in sources.
 986
 987        Args:
 988            source_columns: Mapping of names to source columns.
 989
 990        Returns:
 991            Mapping of column name to source name.
 992        """
 993        if not source_columns:
 994            return {}
 995
 996        source_columns_pairs = list(source_columns.items())
 997
 998        first_table, first_columns = source_columns_pairs[0]
 999
1000        if len(source_columns_pairs) == 1:
1001            # Performance optimization - avoid copying first_columns if there is only one table.
1002            return SingleValuedMapping(first_columns, first_table)
1003
1004        unambiguous_columns = {col: first_table for col in first_columns}
1005        all_columns = set(unambiguous_columns)
1006
1007        for table, columns in source_columns_pairs[1:]:
1008            unique = set(columns)
1009            ambiguous = all_columns.intersection(unique)
1010            all_columns.update(columns)
1011
1012            for column in ambiguous:
1013                unambiguous_columns.pop(column, None)
1014            for column in unique.difference(ambiguous):
1015                unambiguous_columns[column] = table
1016
1017        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
841    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
842        self.scope = scope
843        self.schema = schema
844        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
845        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
846        self._all_columns: t.Optional[t.Set[str]] = None
847        self._infer_schema = infer_schema
848        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    # Build the unambiguous-column mapping lazily, on first lookup.
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    resolved = self._unambiguous_columns.get(column_name)

    if not resolved and self._infer_schema:
        # The column may belong to the single source whose columns are
        # unknown (no schema information, or a star projection).
        schemaless = [
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        ]
        if len(schemaless) == 1:
            resolved = schemaless[0]

    selected = self.scope.selected_sources
    if resolved not in selected:
        return exp.to_identifier(resolved)

    node, _ = selected.get(resolved)

    if isinstance(node, exp.Query):
        # Climb the tree until we reach the node actually aliased `resolved`.
        while node and node.alias != resolved:
            node = node.parent

    # Prefer the node's explicit alias identifier when one exists.
    node_alias = node.args.get("alias")
    return exp.to_identifier(node_alias.this if node_alias else resolved)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
@property
def all_columns(self) -> t.Set[str]:
    """All available columns of all sources in this scope"""
    if self._all_columns is None:
        # Union the column names of every source, computed once and cached.
        collected: t.Set[str] = set()
        for source_columns in self._get_all_source_columns().values():
            collected.update(source_columns)
        self._all_columns = collected
    return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
    """Resolve the source columns for a given source `name`.

    Args:
        name: Name of a source (table, CTE, subquery alias) in this scope.
        only_visible: Whether to restrict schema lookups to visible columns.

    Returns:
        The sequence of column names exposed by the source, with any
        source-level column aliases applied.

    Raises:
        OptimizeError: If `name` is not a source in the current scope.
    """
    cache_key = (name, only_visible)
    if cache_key not in self._get_source_columns_cache:
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(
            source.expression, (exp.Values, exp.Unnest)
        ):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                if source.expression.is_type(exp.DataType.Type.STRUCT):
                    for k in source.expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
            set_op = source.expression

            # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
            on_column_list = set_op.args.get("on")

            if on_column_list:
                # The resulting columns are the columns in the ON clause:
                # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
                columns = [col.name for col in on_column_list]
            elif set_op.side or set_op.kind:
                side = set_op.side
                kind = set_op.kind

                left = set_op.left.named_selects
                right = set_op.right.named_selects

                # We use dict.fromkeys to deduplicate keys and maintain insertion order
                if side == "LEFT":
                    columns = left
                elif side == "FULL":
                    columns = list(dict.fromkeys(left + right))
                elif kind == "INNER":
                    # Intersect while preserving the left operand's column order.
                    # A plain `keys() & keys()` yields an unordered set, which
                    # would make the resulting column order nondeterministic.
                    right_columns = set(right)
                    columns = [col for col in dict.fromkeys(left) if col in right_columns]
                else:
                    # Any other side/kind combination: fall back to the set
                    # operation's own projection instead of leaving `columns`
                    # unbound (which previously raised NameError below).
                    columns = set_op.named_selects
            else:
                columns = set_op.named_selects
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            columns = [
                alias or name
                for (name, alias) in itertools.zip_longest(columns, column_aliases)
            ]

        self._get_source_columns_cache[cache_key] = columns

    return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.