sqlglot.optimizer.qualify_columns

from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            annotator.annotate_scope(scope)

    return expression

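
# Editor's note: an illustrative usage sketch, not part of the original module.
# It only uses the public API shown in the docstring above; the exact qualified
# SQL can vary slightly across sqlglot versions.
def _example_qualify_columns() -> None:
    import sqlglot

    expression = sqlglot.parse_one("SELECT col, other FROM tbl")
    schema = {"tbl": {"col": "INT", "other": "INT"}}
    # Every column is rewritten to reference its source table and given an alias.
    print(qualify_columns(expression, schema).sql())
    # e.g. SELECT tbl.col AS col, tbl.other AS other FROM tbl
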

def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression

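
# Editor's note: illustrative sketch (not part of the original module) showing how
# `validate_qualify_columns` rejects columns that qualification could not resolve.
def _example_validate_qualify_columns() -> None:
    import sqlglot

    # `a` could come from either x or y, so validation raises OptimizeError.
    expression = sqlglot.parse_one("SELECT a FROM x CROSS JOIN y")
    try:
        validate_qualify_columns(expression)
    except OptimizeError as error:
        print(f"rejected: {error}")
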

def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    name_columns = [
        field.this
        for field in unpivot.fields
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    ]
    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))

    return itertools.chain(name_columns, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables

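
# Editor's note: illustrative sketch of the USING expansion above (not part of the
# original module; the schema is made up, and the output shape follows the COALESCE logic).
def _example_expand_using() -> None:
    import sqlglot

    expression = sqlglot.parse_one("SELECT id FROM a JOIN b USING (id)")
    schema = {"a": {"id": "INT"}, "b": {"id": "INT"}}
    # The USING clause becomes an explicit ON condition, and the shared column is
    # replaced by COALESCE(a.id, b.id) so that either side can supply the value.
    print(qualify_columns(expression, schema).sql())
    # e.g. SELECT COALESCE(a.id, b.id) AS id FROM a JOIN b ON a.id = b.id
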

def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select) or dialect == "oracle":
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g.:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, as it would result in FUNC(FUNC(col))
            # This is not required for the HAVING clause, as it can evaluate expressions using both the alias and the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If HAVING x is expanded to max(x.b), bigquery treats x as the new projection x instead of the table
                if is_having and dialect == "bigquery":
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column, dialect)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()

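
# Editor's note: illustrative sketch of alias-reference expansion (not part of the
# original module); it mirrors the docstring example of `_expand_alias_refs`.
def _example_expand_alias_refs() -> None:
    import sqlglot

    expression = sqlglot.parse_one("SELECT y.foo AS bar, bar * 2 AS baz FROM y")
    # The reference to the alias `bar` is replaced by the expression it names.
    print(qualify_columns(expression, schema={"y": {"foo": "INT"}}).sql())
    # e.g. SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
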

def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)

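
# Editor's note: illustrative sketch of positional GROUP BY expansion (not part of
# the original module; the schema is made up).
def _example_expand_group_by() -> None:
    import sqlglot

    expression = sqlglot.parse_one("SELECT col, COUNT(*) AS cnt FROM tbl GROUP BY 1")
    # GROUP BY 1 is replaced by the expression of the first projection.
    print(qualify_columns(expression, schema={"tbl": {"col": "INT"}}).sql())
    # e.g. SELECT tbl.col AS col, COUNT(*) AS cnt FROM tbl GROUP BY tbl.col
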

def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()

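
# Editor's note: illustrative sketch of struct-field qualification (not part of the
# original module; assumes a STRUCT-typed schema, and output may vary by version).
def _example_convert_columns_to_dots() -> None:
    import sqlglot

    # `s.field` parses as a column with table `s`, but `s` is really a struct
    # column of `t`, so it is rewritten into a Dot chain and qualified against `t`.
    expression = sqlglot.parse_one("SELECT s.field FROM t")
    schema = {"t": {"s": "STRUCT<field INT>"}}
    print(qualify_columns(expression, schema).sql())
    # e.g. SELECT t.s.field AS field FROM t
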

def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns under the Pivot expression must be qualified with the name of
                # the pivoted source instead (see the loop below), so they are skipped here
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be '' because BigQuery's UNNEST has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)
            elif (
                resolver.schema.dialect == "bigquery"
                and len(column.parts) == 1
                and column_name in scope.selected_sources
            ):
                # BigQuery allows tables to be referenced as columns, treating them as structs
                scope.replace(column, exp.TableColumn(this=column.this))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars_bigquery(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column into one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct, e.g. t.c.f1.f2.*, find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections

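
# Editor's note: illustrative sketch of BigQuery struct-star expansion (not part of
# the original module; assumes a STRUCT-typed schema and the "bigquery" dialect).
def _example_expand_struct_stars_bigquery() -> None:
    import sqlglot

    expression = sqlglot.parse_one("SELECT t.s.* FROM t", read="bigquery")
    schema = {"t": {"s": "STRUCT<a INT64, b INT64>"}}
    print(qualify_columns(expression, schema, dialect="bigquery").sql("bigquery"))
    # e.g. SELECT t.s.a AS a, t.s.b AS b FROM t
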

def _expand_struct_stars_risingwave(expression: exp.Dot) -> t.List[exp.Alias]:
    """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column"""

    # if this is not the (<sub_exp>).* pattern, we can't expand
    if not isinstance(expression.this, exp.Paren):
        return []

    # find column definition to get data-type
    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    parent = dot_column.parent
    starting_struct = dot_column.type

    # walk up AST and down into struct definition in sync
    while parent is not None:
        if isinstance(parent, exp.Paren):
            parent = parent.parent
            continue

        # if parent is not a dot, then something is wrong
        if not isinstance(parent, exp.Dot):
            return []

        # if the rhs of the dot is star we are done
        rhs = parent.right
        if isinstance(rhs, exp.Star):
            break

        # if it is not an identifier, then something is wrong
        if not isinstance(rhs, exp.Identifier):
            return []

        # Check if current rhs identifier is in struct
        matched = False
        for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
            if struct_field_def.name == rhs.name:
                matched = True
                starting_struct = struct_field_def.kind  # update struct
                break

        if not matched:
            return []

        parent = parent.parent

    # build new aliases to expand star
    new_selections = []

    # fetch the outermost parentheses for the new aliases
    outer_paren = expression.this

    for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
        new_identifier = struct_field_def.this.copy()
        new_dot = exp.Dot.build([outer_paren.copy(), new_identifier])
        new_alias = alias(new_dot, new_identifier, copy=False)
        new_selections.append(new_alias)

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )

        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    is_risingwave = dialect == "risingwave"

    if (is_bigquery or is_risingwave) and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars_bigquery(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue
            elif is_risingwave:
                struct_fields = _expand_struct_stars_risingwave(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)

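
# Editor's note: illustrative sketch of star expansion (not part of the original
# module). The EXCEPT/REPLACE/RENAME star modifiers are collected by the helpers below.
def _example_expand_stars() -> None:
    import sqlglot

    expression = sqlglot.parse_one("SELECT * EXCEPT (b) FROM tbl", read="bigquery")
    schema = {"tbl": {"a": "INT64", "b": "INT64", "c": "INT64"}}
    print(qualify_columns(expression, schema, dialect="bigquery").sql("bigquery"))
    # e.g. SELECT tbl.a AS a, tbl.c AS c FROM tbl
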

def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None or isinstance(selection, exp.QueryTransform):
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)

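
# Editor's note: illustrative sketch (not part of the original module) of how
# `qualify_outputs` aliases unaliased projections.
def _example_qualify_outputs() -> None:
    import sqlglot

    expression = sqlglot.parse_one("SELECT a, b + 1 FROM t")
    qualify_outputs(expression)
    # Named projections keep their name; positional ones get _col_<i>.
    print(expression.sql())
    # e.g. SELECT a AS a, b + 1 AS _col_1 FROM t
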

def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore

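
# Editor's note: illustrative sketch (not part of the original module).
def _example_quote_identifiers() -> None:
    import sqlglot

    expression = sqlglot.parse_one("SELECT x FROM y")
    quoted = quote_identifiers(expression, dialect="snowflake")
    print(quoted.sql(dialect="snowflake"))  # e.g. SELECT "x" FROM "y"
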

def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to push down.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                set_op = source.expression

                # BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
                on_column_list = set_op.args.get("on")

                if on_column_list:
                    # The resulting columns are the columns in the ON clause:
                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
                    columns = [col.name for col in on_column_list]
                elif set_op.side or set_op.kind:
                    side = set_op.side
                    kind = set_op.kind

                    left = set_op.left.named_selects
                    right = set_op.right.named_selects

                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
                    if side == "LEFT":
                        columns = left
                    elif side == "FULL":
                        columns = list(dict.fromkeys(left + right))
                    elif kind == "INNER":
                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
                else:
                    columns = set_op.named_selects
            else:
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
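

# Editor's note: illustrative sketch of using `Resolver` directly (not part of the
# original module; `Resolver` is an internal helper, so treat this as a debugging aid).
def _example_resolver() -> None:
    import sqlglot
    from sqlglot.optimizer.scope import build_scope
    from sqlglot.schema import ensure_schema

    expression = sqlglot.parse_one("SELECT col FROM tbl")
    scope = build_scope(expression)
    assert scope is not None

    resolver = Resolver(scope, ensure_schema({"tbl": {"col": "INT"}}))
    print(resolver.get_table("col"))  # e.g. tbl
    print(resolver.get_source_columns("tbl"))  # e.g. ['col']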
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
 20def qualify_columns(
 21    expression: exp.Expression,
 22    schema: t.Dict | Schema,
 23    expand_alias_refs: bool = True,
 24    expand_stars: bool = True,
 25    infer_schema: t.Optional[bool] = None,
 26    allow_partial_qualification: bool = False,
 27    dialect: DialectType = None,
 28) -> exp.Expression:
 29    """
 30    Rewrite sqlglot AST to have fully qualified columns.
 31
 32    Example:
 33        >>> import sqlglot
 34        >>> schema = {"tbl": {"col": "INT"}}
 35        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 36        >>> qualify_columns(expression, schema).sql()
 37        'SELECT tbl.col AS col FROM tbl'
 38
 39    Args:
 40        expression: Expression to qualify.
 41        schema: Database schema.
 42        expand_alias_refs: Whether to expand references to aliases.
 43        expand_stars: Whether to expand star queries. This is a necessary step
 44            for most of the optimizer's rules to work; do not set to False unless you
 45            know what you're doing!
 46        infer_schema: Whether to infer the schema if missing.
 47        allow_partial_qualification: Whether to allow partial qualification.
 48
 49    Returns:
 50        The qualified expression.
 51
 52    Notes:
 53        - Currently only handles a single PIVOT or UNPIVOT operator
 54    """
 55    schema = ensure_schema(schema, dialect=dialect)
 56    annotator = TypeAnnotator(schema)
 57    infer_schema = schema.empty if infer_schema is None else infer_schema
 58    dialect = Dialect.get_or_raise(schema.dialect)
 59    pseudocolumns = dialect.PSEUDOCOLUMNS
 60    bigquery = dialect == "bigquery"
 61
 62    for scope in traverse_scope(expression):
 63        scope_expression = scope.expression
 64        is_select = isinstance(scope_expression, exp.Select)
 65
 66        if is_select and scope_expression.args.get("connect"):
 67            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
 68            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
 69            scope_expression.transform(
 70                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
 71                copy=False,
 72            )
 73            scope.clear_cache()
 74
 75        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 76        _pop_table_column_aliases(scope.ctes)
 77        _pop_table_column_aliases(scope.derived_tables)
 78        using_column_tables = _expand_using(scope, resolver)
 79
 80        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
 81            _expand_alias_refs(
 82                scope,
 83                resolver,
 84                dialect,
 85                expand_only_groupby=bigquery,
 86            )
 87
 88        _convert_columns_to_dots(scope, resolver)
 89        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
 90
 91        if not schema.empty and expand_alias_refs:
 92            _expand_alias_refs(scope, resolver, dialect)
 93
 94        if is_select:
 95            if expand_stars:
 96                _expand_stars(
 97                    scope,
 98                    resolver,
 99                    using_column_tables,
100                    pseudocolumns,
101                    annotator,
102                )
103            qualify_outputs(scope)
104
105        _expand_group_by(scope, dialect)
106
107        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
108        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
109        _expand_order_by_and_distinct_on(scope, resolver)
110
111        if bigquery:
112            annotator.annotate_scope(scope)
113
114    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
117def validate_qualify_columns(expression: E) -> E:
118    """Raise an `OptimizeError` if any columns aren't qualified"""
119    all_unqualified_columns = []
120    for scope in traverse_scope(expression):
121        if isinstance(scope.expression, exp.Select):
122            unqualified_columns = scope.unqualified_columns
123
124            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
125                column = scope.external_columns[0]
126                for_table = f" for table: '{column.table}'" if column.table else ""
127                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
128
129            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
130                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
131                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
132                # this list here to ensure those in the former category will be excluded.
133                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
134                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
135
136            all_unqualified_columns.extend(unqualified_columns)
137
138    if all_unqualified_columns:
139        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
140
141    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
838def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
839    """Ensure all output columns are aliased"""
840    if isinstance(scope_or_expression, exp.Expression):
841        scope = build_scope(scope_or_expression)
842        if not isinstance(scope, Scope):
843            return
844    else:
845        scope = scope_or_expression
846
847    new_selections = []
848    for i, (selection, aliased_column) in enumerate(
849        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
850    ):
851        if selection is None or isinstance(selection, exp.QueryTransform):
852            break
853
854        if isinstance(selection, exp.Subquery):
855            if not selection.output_name:
856                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
857        elif not isinstance(selection, exp.Alias) and not selection.is_star:
858            selection = alias(
859                selection,
860                alias=selection.output_name or f"_col_{i}",
861                copy=False,
862            )
863        if aliased_column:
864            selection.set("alias", exp.to_identifier(aliased_column))
865
866        new_selections.append(selection)
867
868    if new_selections and isinstance(scope.expression, exp.Select):
869        scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
872def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
873    """Makes sure all identifiers that need to be quoted are quoted."""
874    return expression.transform(
875        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
876    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
879def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
880    """
881    Pushes down the CTE alias columns into the projection,
882
883    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
884
885    Example:
886        >>> import sqlglot
887        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
888        >>> pushdown_cte_alias_columns(expression).sql()
889        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
890
891    Args:
892        expression: Expression to push down.
893
894    Returns:
895        The expression with the CTE aliases pushed down into the projection.
896    """
897    for cte in expression.find_all(exp.CTE):
898        if cte.alias_column_names:
899            new_expressions = []
900            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
901                if isinstance(projection, exp.Alias):
902                    projection.set("alias", _alias)
903                else:
904                    projection = alias(projection, alias=_alias)
905                new_expressions.append(projection)
906            cte.this.set("expressions", new_expressions)
907
908    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake, where the CTE alias columns can be referenced in the HAVING clause.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to push down.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
 911class Resolver:
 912    """
 913    Helper for resolving columns.
 914
 915    This is a class so we can lazily load some things and easily share them across functions.
 916    """
 917
 918    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
 919        self.scope = scope
 920        self.schema = schema
 921        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
 922        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
 923        self._all_columns: t.Optional[t.Set[str]] = None
 924        self._infer_schema = infer_schema
 925        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
 926
 927    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
 928        """
 929        Get the table for a column name.
 930
 931        Args:
 932            column_name: The column name to find the table for.
 933        Returns:
 934            The table name if it can be found/inferred.
 935        """
 936        if self._unambiguous_columns is None:
 937            self._unambiguous_columns = self._get_unambiguous_columns(
 938                self._get_all_source_columns()
 939            )
 940
 941        table_name = self._unambiguous_columns.get(column_name)
 942
 943        if not table_name and self._infer_schema:
 944            sources_without_schema = tuple(
 945                source
 946                for source, columns in self._get_all_source_columns().items()
 947                if not columns or "*" in columns
 948            )
 949            if len(sources_without_schema) == 1:
 950                table_name = sources_without_schema[0]
 951
 952        if table_name not in self.scope.selected_sources:
 953            return exp.to_identifier(table_name)
 954
 955        node, _ = self.scope.selected_sources.get(table_name)
 956
 957        if isinstance(node, exp.Query):
 958            while node and node.alias != table_name:
 959                node = node.parent
 960
 961        node_alias = node.args.get("alias")
 962        if node_alias:
 963            return exp.to_identifier(node_alias.this)
 964
 965        return exp.to_identifier(table_name)
 966
 967    @property
 968    def all_columns(self) -> t.Set[str]:
 969        """All available columns of all sources in this scope"""
 970        if self._all_columns is None:
 971            self._all_columns = {
 972                column for columns in self._get_all_source_columns().values() for column in columns
 973            }
 974        return self._all_columns
 975
 976    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
 977        """Resolve the source columns for a given source `name`."""
 978        cache_key = (name, only_visible)
 979        if cache_key not in self._get_source_columns_cache:
 980            if name not in self.scope.sources:
 981                raise OptimizeError(f"Unknown table: {name}")
 982
 983            source = self.scope.sources[name]
 984
 985            if isinstance(source, exp.Table):
 986                columns = self.schema.column_names(source, only_visible)
 987            elif isinstance(source, Scope) and isinstance(
 988                source.expression, (exp.Values, exp.Unnest)
 989            ):
 990                columns = source.expression.named_selects
 991
 992                # In BigQuery, unnest structs are automatically scoped as tables, so you can
 993                # directly select a struct field in a query.
 994                # This handles the case where the unnest is statically defined.
 995                if self.schema.dialect == "bigquery":
 996                    if source.expression.is_type(exp.DataType.Type.STRUCT):
 997                        for k in source.expression.type.expressions:  # type: ignore
 998                            columns.append(k.name)
 999            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
1000                set_op = source.expression
1001
1002                # BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
1003                on_column_list = set_op.args.get("on")
1004
1005                if on_column_list:
1006                    # The resulting columns are the columns in the ON clause:
1007                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
1008                    columns = [col.name for col in on_column_list]
1009                elif set_op.side or set_op.kind:
1010                    side = set_op.side
1011                    kind = set_op.kind
1012
1013                    left = set_op.left.named_selects
1014                    right = set_op.right.named_selects
1015
1016                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
1017                    if side == "LEFT":
1018                        columns = left
1019                    elif side == "FULL":
1020                        columns = list(dict.fromkeys(left + right))
1021                    elif kind == "INNER":
1022                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
1023                else:
1024                    columns = set_op.named_selects
1025            else:
1026                select = seq_get(source.expression.selects, 0)
1027
1028                if isinstance(select, exp.QueryTransform):
1029                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
1030                    schema = select.args.get("schema")
1031                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
1032                else:
1033                    columns = source.expression.named_selects
1034
1035            node, _ = self.scope.selected_sources.get(name) or (None, None)
1036            if isinstance(node, Scope):
1037                column_aliases = node.expression.alias_column_names
1038            elif isinstance(node, exp.Expression):
1039                column_aliases = node.alias_column_names
1040            else:
1041                column_aliases = []
1042
1043            if column_aliases:
1044                # If the source's columns are aliased, their aliases shadow the corresponding column names.
1045                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
1046                columns = [
1047                    alias or name
1048                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
1049                ]
1050
1051            self._get_source_columns_cache[cache_key] = columns
1052
1053        return self._get_source_columns_cache[cache_key]
1054
1055    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
1056        if self._source_columns is None:
1057            self._source_columns = {
1058                source_name: self.get_source_columns(source_name)
1059                for source_name, source in itertools.chain(
1060                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
1061                )
1062            }
1063        return self._source_columns
1064
1065    def _get_unambiguous_columns(
1066        self, source_columns: t.Dict[str, t.Sequence[str]]
1067    ) -> t.Mapping[str, str]:
1068        """
1069        Find all the unambiguous columns in sources.
1070
1071        Args:
1072            source_columns: Mapping of names to source columns.
1073
1074        Returns:
1075            Mapping of column name to source name.
1076        """
1077        if not source_columns:
1078            return {}
1079
1080        source_columns_pairs = list(source_columns.items())
1081
1082        first_table, first_columns = source_columns_pairs[0]
1083
1084        if len(source_columns_pairs) == 1:
1085            # Performance optimization - avoid copying first_columns if there is only one table.
1086            return SingleValuedMapping(first_columns, first_table)
1087
1088        unambiguous_columns = {col: first_table for col in first_columns}
1089        all_columns = set(unambiguous_columns)
1090
1091        for table, columns in source_columns_pairs[1:]:
1092            unique = set(columns)
1093            ambiguous = all_columns.intersection(unique)
1094            all_columns.update(columns)
1095
1096            for column in ambiguous:
1097                unambiguous_columns.pop(column, None)
1098            for column in unique.difference(ambiguous):
1099                unambiguous_columns[column] = table
1100
1101        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.
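
A typical lifecycle builds the Resolver from the scope it will resolve against (a minimal sketch; the query and schema are illustrative):

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.schema import ensure_schema
>>> scope = build_scope(sqlglot.parse_one("SELECT col FROM tbl"))
>>> resolver = Resolver(scope, ensure_schema({"tbl": {"col": "INT"}}))
>>> resolver.get_table("col").name
'tbl'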

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
918    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
919        self.scope = scope
920        self.schema = schema
921        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
922        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
923        self._all_columns: t.Optional[t.Set[str]] = None
924        self._infer_schema = infer_schema
925        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
927    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
928        """
929        Get the table for a column name.
930
931        Args:
932            column_name: The column name to find the table for.
933        Returns:
934            The table name if it can be found/inferred.
935        """
936        if self._unambiguous_columns is None:
937            self._unambiguous_columns = self._get_unambiguous_columns(
938                self._get_all_source_columns()
939            )
940
941        table_name = self._unambiguous_columns.get(column_name)
942
943        if not table_name and self._infer_schema:
944            sources_without_schema = tuple(
945                source
946                for source, columns in self._get_all_source_columns().items()
947                if not columns or "*" in columns
948            )
949            if len(sources_without_schema) == 1:
950                table_name = sources_without_schema[0]
951
952        if table_name not in self.scope.selected_sources:
953            return exp.to_identifier(table_name)
954
955        node, _ = self.scope.selected_sources.get(table_name)
956
957        if isinstance(node, exp.Query):
958            while node and node.alias != table_name:
959                node = node.parent
960
961        node_alias = node.args.get("alias")
962        if node_alias:
963            return exp.to_identifier(node_alias.this)
964
965        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.
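
When a column name appears in more than one source there is no single table to infer, so the lookup returns None (continuing the sketch above; t1 and t2 are illustrative):

>>> scope = build_scope(sqlglot.parse_one("SELECT col FROM t1 CROSS JOIN t2"))
>>> resolver = Resolver(scope, ensure_schema({"t1": {"col": "INT"}, "t2": {"col": "INT"}}))
>>> resolver.get_table("col") is None
True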

all_columns: Set[str]
967    @property
968    def all_columns(self) -> t.Set[str]:
969        """All available columns of all sources in this scope"""
970        if self._all_columns is None:
971            self._all_columns = {
972                column for columns in self._get_all_source_columns().values() for column in columns
973            }
974        return self._all_columns

All available columns of all sources in this scope
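
The property simply unions every source's columns, which is handy for checking whether an unqualified name exists anywhere in the scope (continuing the sketch; schemas illustrative):

>>> scope = build_scope(sqlglot.parse_one("SELECT * FROM t1 CROSS JOIN t2"))
>>> resolver = Resolver(scope, ensure_schema({"t1": {"a": "INT"}, "t2": {"b": "INT"}}))
>>> sorted(resolver.all_columns)
['a', 'b']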

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
 976    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
 977        """Resolve the source columns for a given source `name`."""
 978        cache_key = (name, only_visible)
 979        if cache_key not in self._get_source_columns_cache:
 980            if name not in self.scope.sources:
 981                raise OptimizeError(f"Unknown table: {name}")
 982
 983            source = self.scope.sources[name]
 984
 985            if isinstance(source, exp.Table):
 986                columns = self.schema.column_names(source, only_visible)
 987            elif isinstance(source, Scope) and isinstance(
 988                source.expression, (exp.Values, exp.Unnest)
 989            ):
 990                columns = source.expression.named_selects
 991
 992                # In BigQuery, unnest structs are automatically scoped as tables, so you can
 993                # directly select a struct field in a query.
 994                # This handles the case where the unnest is statically defined.
 995                if self.schema.dialect == "bigquery":
 996                    if source.expression.is_type(exp.DataType.Type.STRUCT):
 997                        for k in source.expression.type.expressions:  # type: ignore
 998                            columns.append(k.name)
 999            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
1000                set_op = source.expression
1001
1002                # BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
1003                on_column_list = set_op.args.get("on")
1004
1005                if on_column_list:
1006                    # The resulting columns are the columns in the ON clause:
1007                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
1008                    columns = [col.name for col in on_column_list]
1009                elif set_op.side or set_op.kind:
1010                    side = set_op.side
1011                    kind = set_op.kind
1012
1013                    left = set_op.left.named_selects
1014                    right = set_op.right.named_selects
1015
1016                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
1017                    if side == "LEFT":
1018                        columns = left
1019                    elif side == "FULL":
1020                        columns = list(dict.fromkeys(left + right))
1021                    elif kind == "INNER":
1022                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
1023                else:
1024                    columns = set_op.named_selects
1025            else:
1026                select = seq_get(source.expression.selects, 0)
1027
1028                if isinstance(select, exp.QueryTransform):
1029                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
1030                    schema = select.args.get("schema")
1031                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
1032                else:
1033                    columns = source.expression.named_selects
1034
1035            node, _ = self.scope.selected_sources.get(name) or (None, None)
1036            if isinstance(node, Scope):
1037                column_aliases = node.expression.alias_column_names
1038            elif isinstance(node, exp.Expression):
1039                column_aliases = node.alias_column_names
1040            else:
1041                column_aliases = []
1042
1043            if column_aliases:
1044                # If the source's columns are aliased, their aliases shadow the corresponding column names.
1045                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
1046                columns = [
1047                    alias or name
1048                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
1049                ]
1050
1051            self._get_source_columns_cache[cache_key] = columns
1052
1053        return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.
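
For a real table the lookup defers to the schema, while a derived table falls back to its named selects (continuing the sketch above; names illustrative):

>>> sql = "SELECT a FROM t1 CROSS JOIN (SELECT 1 AS b) AS sub"
>>> scope = build_scope(sqlglot.parse_one(sql))
>>> resolver = Resolver(scope, ensure_schema({"t1": {"a": "INT"}}))
>>> resolver.get_source_columns("t1")
['a']
>>> resolver.get_source_columns("sub")
['b']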