Edit on GitHub

sqlglot.optimizer.qualify_columns

   1from __future__ import annotations
   2
   3import itertools
   4import typing as t
   5
   6from sqlglot import alias, exp
   7from sqlglot.dialects.dialect import Dialect, DialectType
   8from sqlglot.errors import OptimizeError
   9from sqlglot.helper import seq_get, SingleValuedMapping
  10from sqlglot.optimizer.annotate_types import TypeAnnotator
  11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
  12from sqlglot.optimizer.simplify import simplify_parens
  13from sqlglot.schema import Schema, ensure_schema
  14
  15if t.TYPE_CHECKING:
  16    from sqlglot._typing import E
  17
  18
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.
        dialect: Dialect used when building the schema from a plain mapping.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    # Default to inferring a schema only when none was provided
    infer_schema = schema.empty if infer_schema is None else infer_schema
    # From here on, the dialect is taken from the schema, not the `dialect` argument
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references may need to be expanded *before* qualification when there is
        # no schema to resolve against (or the dialect requires early expansion)
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        # With a schema available, alias expansion happens after qualification instead
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            annotator.annotate_scope(scope)

    return expression
 114
 115
 116def validate_qualify_columns(expression: E) -> E:
 117    """Raise an `OptimizeError` if any columns aren't qualified"""
 118    all_unqualified_columns = []
 119    for scope in traverse_scope(expression):
 120        if isinstance(scope.expression, exp.Select):
 121            unqualified_columns = scope.unqualified_columns
 122
 123            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 124                column = scope.external_columns[0]
 125                for_table = f" for table: '{column.table}'" if column.table else ""
 126                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 127
 128            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 129                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 130                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 131                # this list here to ensure those in the former category will be excluded.
 132                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 133                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 134
 135            all_unqualified_columns.extend(unqualified_columns)
 136
 137    if all_unqualified_columns:
 138        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
 139
 140    return expression
 141
 142
 143def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
 144    name_columns = [
 145        field.this
 146        for field in unpivot.fields
 147        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
 148    ]
 149    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
 150
 151    return itertools.chain(name_columns, value_columns)
 152
 153
 154def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 155    """
 156    Remove table column aliases.
 157
 158    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 159    """
 160    for derived_table in derived_tables:
 161        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
 162            continue
 163        table_alias = derived_table.args.get("alias")
 164        if table_alias:
 165            table_alias.args.pop("columns", None)
 166
 167
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite JOIN ... USING clauses into equivalent JOIN ... ON conditions.

    Unqualified references to USING columns elsewhere in the scope are replaced by
    COALESCE expressions over all sources that provide the column.

    Returns:
        Mapping of automatically joined column names to an ordered set of source
        names (represented as a dict whose values are ignored).
    """
    # Maps each column name to the first (leftmost) source that provides it
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that are not joins, i.e. the FROM clause sources
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        # The most recently appended source acts as the left-hand side of this join
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only raise when we know the columns of both sides; with an unknown
                # ("*") source we can't prove the join is invalid
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # Later joins may need to COALESCE over every earlier source that
                # provides the column, since the USING column is already merged
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        # Replace the USING clause with the equivalent ON condition
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
 264
 265
def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.

    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y

    Args:
        scope: The scope whose expression is rewritten in place.
        resolver: Used to resolve unqualified columns to their tables.
        dialect: The active dialect; several branches below are dialect-specific.
        expand_only_groupby: When True, only expand aliases inside the GROUP BY clause.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select) or dialect == "oracle":
        return

    # Maps each projection alias to (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}
    is_bigquery = dialect == "bigquery"

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                # Don't expand an aggregate alias inside another aggregate (unless the
                # reference sits under a window), to avoid nested aggregations
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x" reference with the projection MAX(x.b)
                # i.e HAVING MAX(MAX(x.b).b), resulting in the error: "Aggregations of aggregations are not allowed"
                if is_having and is_bigquery:
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )
            elif is_bigquery and (is_group_by or is_having):
                column_table = table.name if table else column.table
                if column_table in projections:
                    # BigQuery's GROUP BY and HAVING clauses get confused if the column name
                    # matches a source name and a projection. For instance:
                    # SELECT id, ARRAY_AGG(col) AS custom_fields FROM custom_fields GROUP BY id HAVING id >= 1
                    # We should not qualify "id" with "custom_fields" in either clause, since the aggregation shadows the actual table
                    # and we'd get the error: "Column custom_fields contains an aggregation function, which is not allowed in GROUP BY clause"
                    column.replace(exp.to_identifier(column.name))
                    return

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if (isinstance(alias_expr, exp.Literal) or alias_expr.is_number) and (
                    literal_index or resolve_table
                ):
                    # Literal aliases in GROUP BY are replaced by their ordinal position
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    # Wrap in parens to preserve precedence, then simplify them away if safe
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column, dialect)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    # Detect whether this scope lives in the right sub tree of a UNION under a CTE
    parent_scope = scope
    on_right_sub_tree = False
    while parent_scope and not parent_scope.is_cte:
        if parent_scope.is_union:
            on_right_sub_tree = parent_scope.parent.expression.right is parent_scope.expression
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    # and we are in the recursive part (right sub tree) of the CTE
    if parent_scope and on_right_sub_tree:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()
 380
 381
 382def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
 383    expression = scope.expression
 384    group = expression.args.get("group")
 385    if not group:
 386        return
 387
 388    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
 389    expression.set("group", group)
 390
 391
def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references in ORDER BY and DISTINCT ON clauses.

    Both clauses are handled identically (DISTINCT ON follows ORDER BY's rules).
    Expansion happens in place on the scope's expression.
    """
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            # For DISTINCT, the expressions of interest live under the ON clause
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            # Unwrap exp.Ordered nodes to get the sort key expressions themselves
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            # Qualify columns inside aggregates before replacing the node
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            # With a GROUP BY present, refer to projections by their output names
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )
 427
 428
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """
    Replace integer literals in `expressions` with the projections they reference.

    Args:
        scope: The scope whose selects the positions index into (1-based).
        expressions: The candidate nodes; non-integer nodes pass through unchanged.
        dialect: The active dialect (BigQuery needs extra ambiguity handling).
        alias: When True, replace a position with the projection's alias instead of
            a copy of the projected expression.

    Returns:
        A new list of nodes; positions that can't safely be expanded (constants,
        numbers, EXPLODE/UNNEST projections, or BigQuery-ambiguous names) are kept
        as the original ordinal.
    """
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.is_number
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    # Keep the ordinal: expanding here would change semantics or confuse BQ
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
 474
 475
 476def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
 477    try:
 478        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
 479    except IndexError:
 480        raise OptimizeError(f"Unknown output column: {node.name}")
 481
 482
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # Only treat the column as a struct lookup if its "table" part isn't actually
        # a known source (neither here nor, for correlated subqueries, in the parent)
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
 522
 523
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            # Already qualified: validate the column actually exists on that source
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)
            elif (
                resolver.schema.dialect == "bigquery"
                and len(column.parts) == 1
                and column_name in scope.selected_sources
            ):
                # BigQuery allows tables to be referenced as columns, treating them as structs
                scope.replace(column, exp.TableColumn(this=column.this))

    # Columns inside the PIVOT clause itself are qualified in a second pass
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
 565
 566
def _expand_struct_stars_bigquery(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column

    Returns an empty list when expansion is not possible (non-struct column,
    unnamed/ambiguous fields, or no matching nested field).
    """

    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        # Build <table>.<column>.<...parts>.<field> AS <field> for each struct field
        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
 619
 620
def _expand_struct_stars_risingwave(expression: exp.Dot) -> t.List[exp.Alias]:
    """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column

    Returns an empty list when the pattern doesn't match or expansion is not possible.
    """

    # it is not (<sub_exp>).* pattern, which means we can't expand
    if not isinstance(expression.this, exp.Paren):
        return []

    # find column definition to get data-type
    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    parent = dot_column.parent
    starting_struct = dot_column.type

    # walk up AST and down into struct definition in sync
    while parent is not None:
        if isinstance(parent, exp.Paren):
            parent = parent.parent
            continue

        # if parent is not a dot, then something is wrong
        if not isinstance(parent, exp.Dot):
            return []

        # if the rhs of the dot is star we are done
        rhs = parent.right
        if isinstance(rhs, exp.Star):
            break

        # if it is not identifier, then something is wrong
        if not isinstance(rhs, exp.Identifier):
            return []

        # Check if current rhs identifier is in struct
        matched = False
        for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
            if struct_field_def.name == rhs.name:
                matched = True
                starting_struct = struct_field_def.kind  # update struct
                break

        if not matched:
            return []

        parent = parent.parent

    # build new aliases to expand star
    new_selections = []

    # fetch the outermost parentheses for new aliases
    outer_paren = expression.this

    # one projection (<paren>.<field> AS <field>) per field of the resolved struct
    for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
        new_identifier = struct_field_def.this.copy()
        new_dot = exp.Dot.build([outer_paren.copy(), new_identifier])
        new_alias = alias(new_dot, new_identifier, copy=False)
        new_selections.append(new_alias)

    return new_selections
 681
 682
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """
    Expand stars to lists of column selections.

    Rewrites `SELECT *` / `SELECT tbl.*` projections in `scope` into explicit,
    aliased column lists, honoring EXCEPT / REPLACE / RENAME star modifiers,
    USING-based column coalescing, PIVOT/UNPIVOT output columns, and
    BigQuery / RisingWave struct-star expansion. Mutates `scope.expression`
    in place; returns early (leaving selections untouched) if any source's
    columns cannot be determined.
    """

    new_selections: t.List[exp.Expression] = []
    # All three maps are keyed by id(table) of the table names collected below
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    # Only the first PIVOT/UNPIVOT is considered (see module notes)
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            # Columns listed in the UNPIVOT's IN clause are consumed by the
            # operator and must not appear in the expanded star
            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )

        else:
            # PIVOT: columns referenced by the pivot itself are consumed
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    is_risingwave = dialect == "risingwave"

    if (is_bigquery or is_risingwave) and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time so struct field
        # types are available to the _expand_struct_stars_* helpers
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Unqualified `*`: expands over every selected source
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified `tbl.*`: expands over a single source
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                # `struct_col.*` — expand the struct's fields if types are known
                struct_fields = _expand_struct_stars_bigquery(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue
            elif is_risingwave:
                struct_fields = _expand_struct_stars_risingwave(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            # Not a star (or an unexpandable struct star): keep as-is
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                # Pseudocolumns (e.g. Oracle ROWNUM) never participate in `*`
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Unknown column set for this source — bail out without
                # touching the selections at all
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # JOIN ... USING: emit a single COALESCE over all tables
                    # sharing the column, only once per column name
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
 812
 813
 814def _add_except_columns(
 815    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
 816) -> None:
 817    except_ = expression.args.get("except")
 818
 819    if not except_:
 820        return
 821
 822    columns = {e.name for e in except_}
 823
 824    for table in tables:
 825        except_columns[id(table)] = columns
 826
 827
 828def _add_rename_columns(
 829    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
 830) -> None:
 831    rename = expression.args.get("rename")
 832
 833    if not rename:
 834        return
 835
 836    columns = {e.this.name: e.alias for e in rename}
 837
 838    for table in tables:
 839        rename_columns[id(table)] = columns
 840
 841
 842def _add_replace_columns(
 843    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
 844) -> None:
 845    replace = expression.args.get("replace")
 846
 847    if not replace:
 848        return
 849
 850    columns = {e.alias: e for e in replace}
 851
 852    for table in tables:
 853        replace_columns[id(table)] = columns
 854
 855
 856def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
 857    """Ensure all output columns are aliased"""
 858    if isinstance(scope_or_expression, exp.Expression):
 859        scope = build_scope(scope_or_expression)
 860        if not isinstance(scope, Scope):
 861            return
 862    else:
 863        scope = scope_or_expression
 864
 865    new_selections = []
 866    for i, (selection, aliased_column) in enumerate(
 867        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
 868    ):
 869        if selection is None or isinstance(selection, exp.QueryTransform):
 870            break
 871
 872        if isinstance(selection, exp.Subquery):
 873            if not selection.output_name:
 874                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
 875        elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star:
 876            selection = alias(
 877                selection,
 878                alias=selection.output_name or f"_col_{i}",
 879                copy=False,
 880            )
 881        if aliased_column:
 882            selection.set("alias", exp.to_identifier(aliased_column))
 883
 884        new_selections.append(selection)
 885
 886    if new_selections and isinstance(scope.expression, exp.Select):
 887        scope.expression.set("expressions", new_selections)
 888
 889
 890def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
 891    """Makes sure all identifiers that need to be quoted are quoted."""
 892    return expression.transform(
 893        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
 894    )  # type: ignore
 895
 896
 897def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
 898    """
 899    Pushes down the CTE alias columns into the projection,
 900
 901    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
 902
 903    Example:
 904        >>> import sqlglot
 905        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
 906        >>> pushdown_cte_alias_columns(expression).sql()
 907        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
 908
 909    Args:
 910        expression: Expression to pushdown.
 911
 912    Returns:
 913        The expression with the CTE aliases pushed down into the projection.
 914    """
 915    for cte in expression.find_all(exp.CTE):
 916        if cte.alias_column_names:
 917            new_expressions = []
 918            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
 919                if isinstance(projection, exp.Alias):
 920                    projection.set("alias", _alias)
 921                else:
 922                    projection = alias(projection, alias=_alias)
 923                new_expressions.append(projection)
 924            cte.this.set("expressions", new_expressions)
 925
 926    return expression
 927
 928
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-computed caches, populated on first access
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Cache keyed by (source name, only_visible) to avoid re-resolving columns
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has an unknown column set, assume the
            # column belongs to it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        # NOTE(review): table_name may still be None here, in which case
        # exp.to_identifier presumably returns None — confirm
        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the node that actually carries this alias
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
        """Compute the output column names of a set operation (e.g. UNION), recursively."""
        if isinstance(expression, exp.Select):
            return expression.named_selects
        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
            return self.get_source_columns_from_set_op(expression.this)
        if not isinstance(expression, exp.SetOperation):
            raise OptimizeError(f"Unknown set operation: {expression}")

        set_op = expression

        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
        on_column_list = set_op.args.get("on")

        if on_column_list:
            # The resulting columns are the columns in the ON clause:
            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
            columns = [col.name for col in on_column_list]
        elif set_op.side or set_op.kind:
            side = set_op.side
            kind = set_op.kind

            # Visit the children UNIONs (if any) in a post-order traversal
            left = self.get_source_columns_from_set_op(set_op.left)
            right = self.get_source_columns_from_set_op(set_op.right)

            # We use dict.fromkeys to deduplicate keys and maintain insertion order
            if side == "LEFT":
                columns = left
            elif side == "FULL":
                columns = list(dict.fromkeys(left + right))
            elif kind == "INNER":
                # NOTE(review): `&` on dict key views yields a set, so the
                # resulting column order is not guaranteed here — confirm
                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
            # NOTE(review): if side/kind is set but matches none of the
            # branches above, `columns` is unbound and a NameError would be
            # raised below — verify all side/kind combinations are covered
        else:
            columns = set_op.named_selects

        return columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Physical table: look the columns up in the schema
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                columns = self.get_source_columns_from_set_op(source.expression)

            else:
                # Derived table / CTE: columns come from its projections
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Map every selected and lateral source name to its resolved columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            # A column seen in any earlier source is ambiguous from now on
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
 20def qualify_columns(
 21    expression: exp.Expression,
 22    schema: t.Dict | Schema,
 23    expand_alias_refs: bool = True,
 24    expand_stars: bool = True,
 25    infer_schema: t.Optional[bool] = None,
 26    allow_partial_qualification: bool = False,
 27    dialect: DialectType = None,
 28) -> exp.Expression:
 29    """
 30    Rewrite sqlglot AST to have fully qualified columns.
 31
 32    Example:
 33        >>> import sqlglot
 34        >>> schema = {"tbl": {"col": "INT"}}
 35        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 36        >>> qualify_columns(expression, schema).sql()
 37        'SELECT tbl.col AS col FROM tbl'
 38
 39    Args:
 40        expression: Expression to qualify.
 41        schema: Database schema.
 42        expand_alias_refs: Whether to expand references to aliases.
 43        expand_stars: Whether to expand star queries. This is a necessary step
 44            for most of the optimizer's rules to work; do not set to False unless you
 45            know what you're doing!
 46        infer_schema: Whether to infer the schema if missing.
 47        allow_partial_qualification: Whether to allow partial qualification.
 48
 49    Returns:
 50        The qualified expression.
 51
 52    Notes:
 53        - Currently only handles a single PIVOT or UNPIVOT operator
 54    """
 55    schema = ensure_schema(schema, dialect=dialect)
 56    annotator = TypeAnnotator(schema)
 57    infer_schema = schema.empty if infer_schema is None else infer_schema
 58    dialect = Dialect.get_or_raise(schema.dialect)
 59    pseudocolumns = dialect.PSEUDOCOLUMNS
 60    bigquery = dialect == "bigquery"
 61
 62    for scope in traverse_scope(expression):
 63        scope_expression = scope.expression
 64        is_select = isinstance(scope_expression, exp.Select)
 65
 66        if is_select and scope_expression.args.get("connect"):
 67            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
 68            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
 69            scope_expression.transform(
 70                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
 71                copy=False,
 72            )
 73            scope.clear_cache()
 74
 75        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 76        _pop_table_column_aliases(scope.ctes)
 77        _pop_table_column_aliases(scope.derived_tables)
 78        using_column_tables = _expand_using(scope, resolver)
 79
 80        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
 81            _expand_alias_refs(
 82                scope,
 83                resolver,
 84                dialect,
 85                expand_only_groupby=bigquery,
 86            )
 87
 88        _convert_columns_to_dots(scope, resolver)
 89        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
 90
 91        if not schema.empty and expand_alias_refs:
 92            _expand_alias_refs(scope, resolver, dialect)
 93
 94        if is_select:
 95            if expand_stars:
 96                _expand_stars(
 97                    scope,
 98                    resolver,
 99                    using_column_tables,
100                    pseudocolumns,
101                    annotator,
102                )
103            qualify_outputs(scope)
104
105        _expand_group_by(scope, dialect)
106
107        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
108        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
109        _expand_order_by_and_distinct_on(scope, resolver)
110
111        if bigquery:
112            annotator.annotate_scope(scope)
113
114    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
117def validate_qualify_columns(expression: E) -> E:
118    """Raise an `OptimizeError` if any columns aren't qualified"""
119    all_unqualified_columns = []
120    for scope in traverse_scope(expression):
121        if isinstance(scope.expression, exp.Select):
122            unqualified_columns = scope.unqualified_columns
123
124            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
125                column = scope.external_columns[0]
126                for_table = f" for table: '{column.table}'" if column.table else ""
127                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
128
129            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
130                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
131                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
132                # this list here to ensure those in the former category will be excluded.
133                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
134                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
135
136            all_unqualified_columns.extend(unqualified_columns)
137
138    if all_unqualified_columns:
139        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
140
141    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
857def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
858    """Ensure all output columns are aliased"""
859    if isinstance(scope_or_expression, exp.Expression):
860        scope = build_scope(scope_or_expression)
861        if not isinstance(scope, Scope):
862            return
863    else:
864        scope = scope_or_expression
865
866    new_selections = []
867    for i, (selection, aliased_column) in enumerate(
868        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
869    ):
870        if selection is None or isinstance(selection, exp.QueryTransform):
871            break
872
873        if isinstance(selection, exp.Subquery):
874            if not selection.output_name:
875                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
876        elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star:
877            selection = alias(
878                selection,
879                alias=selection.output_name or f"_col_{i}",
880                copy=False,
881            )
882        if aliased_column:
883            selection.set("alias", exp.to_identifier(aliased_column))
884
885        new_selections.append(selection)
886
887    if new_selections and isinstance(scope.expression, exp.Select):
888        scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
891def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
892    """Makes sure all identifiers that need to be quoted are quoted."""
893    return expression.transform(
894        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
895    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
898def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
899    """
900    Pushes down the CTE alias columns into the projection,
901
902    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
903
904    Example:
905        >>> import sqlglot
906        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
907        >>> pushdown_cte_alias_columns(expression).sql()
908        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
909
910    Args:
911        expression: Expression to pushdown.
912
913    Returns:
914        The expression with the CTE aliases pushed down into the projection.
915    """
916    for cte in expression.find_all(exp.CTE):
917        if cte.alias_column_names:
918            new_expressions = []
919            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
920                if isinstance(projection, exp.Alias):
921                    projection.set("alias", _alias)
922                else:
923                    projection = alias(projection, alias=_alias)
924                new_expressions.append(projection)
925            cte.this.set("expressions", new_expressions)
926
927    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches; each is computed at most once per Resolver.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source lacks column information, assume the
            # unresolved column belongs to it.
            schemaless_sources = [
                source_name
                for source_name, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            ]
            if len(schemaless_sources) == 1:
                table_name = schemaless_sources[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up until we reach the node that actually carries this alias.
            while node and node.alias != table_name:
                node = node.parent

        alias_arg = node.args.get("alias")
        return exp.to_identifier(alias_arg.this if alias_arg else table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            collected: t.Set[str] = set()
            for columns in self._get_all_source_columns().values():
                collected.update(columns)
            self._all_columns = collected
        return self._all_columns

    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
        if isinstance(expression, exp.Select):
            return expression.named_selects
        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
            return self.get_source_columns_from_set_op(expression.this)
        if not isinstance(expression, exp.SetOperation):
            raise OptimizeError(f"Unknown set operation: {expression}")

        on_clause = expression.args.get("on")

        if on_clause:
            # The resulting columns are the columns in the ON clause:
            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
            return [column.name for column in on_clause]

        if not (expression.side or expression.kind):
            return expression.named_selects

        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME.
        # Visit the children UNIONs (if any) in a post-order traversal.
        left_columns = self.get_source_columns_from_set_op(expression.left)
        right_columns = self.get_source_columns_from_set_op(expression.right)

        # We use dict.fromkeys to deduplicate keys and maintain insertion order
        if expression.side == "LEFT":
            columns = left_columns
        elif expression.side == "FULL":
            columns = list(dict.fromkeys(left_columns + right_columns))
        elif expression.kind == "INNER":
            columns = list(dict.fromkeys(left_columns).keys() & dict.fromkeys(right_columns).keys())

        return columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key in self._get_source_columns_cache:
            return self._get_source_columns_cache[cache_key]

        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery" and source.expression.is_type(
                exp.DataType.Type.STRUCT
            ):
                for field in source.expression.type.expressions:  # type: ignore
                    columns.append(field.name)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
            columns = self.get_source_columns_from_set_op(source.expression)
        else:
            first_select = seq_get(source.expression.selects, 0)

            if isinstance(first_select, exp.QueryTransform):
                # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                transform_schema = first_select.args.get("schema")
                columns = (
                    [c.name for c in transform_schema.expressions]
                    if transform_schema
                    else ["key", "value"]
                )
            else:
                columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            columns = [a or c for c, a in itertools.zip_longest(columns, column_aliases)]

        self._get_source_columns_cache[cache_key] = columns
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            all_sources = itertools.chain(
                self.scope.selected_sources.items(), self.scope.lateral_sources.items()
            )
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, _ in all_sources
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        pairs = list(source_columns.items())
        first_table, first_columns = pairs[0]

        if len(pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous = {column: first_table for column in first_columns}
        seen = set(unambiguous)

        for table, columns in pairs[1:]:
            current = set(columns)
            clashes = seen & current
            seen |= current

            # A column appearing in more than one source is ambiguous; drop it.
            for column in clashes:
                unambiguous.pop(column, None)
            for column in current - clashes:
                unambiguous[column] = table

        return unambiguous

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
937    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
938        self.scope = scope
939        self.schema = schema
940        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
941        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
942        self._all_columns: t.Optional[t.Set[str]] = None
943        self._infer_schema = infer_schema
944        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
946    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
947        """
948        Get the table for a column name.
949
950        Args:
951            column_name: The column name to find the table for.
952        Returns:
953            The table name if it can be found/inferred.
954        """
955        if self._unambiguous_columns is None:
956            self._unambiguous_columns = self._get_unambiguous_columns(
957                self._get_all_source_columns()
958            )
959
960        table_name = self._unambiguous_columns.get(column_name)
961
962        if not table_name and self._infer_schema:
963            sources_without_schema = tuple(
964                source
965                for source, columns in self._get_all_source_columns().items()
966                if not columns or "*" in columns
967            )
968            if len(sources_without_schema) == 1:
969                table_name = sources_without_schema[0]
970
971        if table_name not in self.scope.selected_sources:
972            return exp.to_identifier(table_name)
973
974        node, _ = self.scope.selected_sources.get(table_name)
975
976        if isinstance(node, exp.Query):
977            while node and node.alias != table_name:
978                node = node.parent
979
980        node_alias = node.args.get("alias")
981        if node_alias:
982            return exp.to_identifier(node_alias.this)
983
984        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
986    @property
987    def all_columns(self) -> t.Set[str]:
988        """All available columns of all sources in this scope"""
989        if self._all_columns is None:
990            self._all_columns = {
991                column for columns in self._get_all_source_columns().values() for column in columns
992            }
993        return self._all_columns

All available columns of all sources in this scope

def get_source_columns_from_set_op(self, expression: sqlglot.expressions.Expression) -> List[str]:
 995    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
 996        if isinstance(expression, exp.Select):
 997            return expression.named_selects
 998        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
 999            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
1000            return self.get_source_columns_from_set_op(expression.this)
1001        if not isinstance(expression, exp.SetOperation):
1002            raise OptimizeError(f"Unknown set operation: {expression}")
1003
1004        set_op = expression
1005
1006        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
1007        on_column_list = set_op.args.get("on")
1008
1009        if on_column_list:
1010            # The resulting columns are the columns in the ON clause:
1011            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
1012            columns = [col.name for col in on_column_list]
1013        elif set_op.side or set_op.kind:
1014            side = set_op.side
1015            kind = set_op.kind
1016
1017            # Visit the children UNIONs (if any) in a post-order traversal
1018            left = self.get_source_columns_from_set_op(set_op.left)
1019            right = self.get_source_columns_from_set_op(set_op.right)
1020
1021            # We use dict.fromkeys to deduplicate keys and maintain insertion order
1022            if side == "LEFT":
1023                columns = left
1024            elif side == "FULL":
1025                columns = list(dict.fromkeys(left + right))
1026            elif kind == "INNER":
1027                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
1028        else:
1029            columns = set_op.named_selects
1030
1031        return columns
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
1033    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
1034        """Resolve the source columns for a given source `name`."""
1035        cache_key = (name, only_visible)
1036        if cache_key not in self._get_source_columns_cache:
1037            if name not in self.scope.sources:
1038                raise OptimizeError(f"Unknown table: {name}")
1039
1040            source = self.scope.sources[name]
1041
1042            if isinstance(source, exp.Table):
1043                columns = self.schema.column_names(source, only_visible)
1044            elif isinstance(source, Scope) and isinstance(
1045                source.expression, (exp.Values, exp.Unnest)
1046            ):
1047                columns = source.expression.named_selects
1048
1049                # in bigquery, unnest structs are automatically scoped as tables, so you can
1050                # directly select a struct field in a query.
1051                # this handles the case where the unnest is statically defined.
1052                if self.schema.dialect == "bigquery":
1053                    if source.expression.is_type(exp.DataType.Type.STRUCT):
1054                        for k in source.expression.type.expressions:  # type: ignore
1055                            columns.append(k.name)
1056            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
1057                columns = self.get_source_columns_from_set_op(source.expression)
1058
1059            else:
1060                select = seq_get(source.expression.selects, 0)
1061
1062                if isinstance(select, exp.QueryTransform):
1063                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
1064                    schema = select.args.get("schema")
1065                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
1066                else:
1067                    columns = source.expression.named_selects
1068
1069            node, _ = self.scope.selected_sources.get(name) or (None, None)
1070            if isinstance(node, Scope):
1071                column_aliases = node.expression.alias_column_names
1072            elif isinstance(node, exp.Expression):
1073                column_aliases = node.alias_column_names
1074            else:
1075                column_aliases = []
1076
1077            if column_aliases:
1078                # If the source's columns are aliased, their aliases shadow the corresponding column names.
1079                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
1080                columns = [
1081                    alias or name
1082                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
1083                ]
1084
1085            self._get_source_columns_cache[cache_key] = columns
1086
1087        return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.