sqlglot.optimizer.qualify_columns

   1from __future__ import annotations
   2
   3import itertools
   4import typing as t
   5
   6from sqlglot import alias, exp
   7from sqlglot.dialects.dialect import Dialect, DialectType
   8from sqlglot.errors import OptimizeError
   9from sqlglot.helper import seq_get, SingleValuedMapping
  10from sqlglot.optimizer.annotate_types import TypeAnnotator
  11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
  12from sqlglot.optimizer.simplify import simplify_parens
  13from sqlglot.schema import Schema, ensure_schema
  14
  15if t.TYPE_CHECKING:
  16    from sqlglot._typing import E
  17
  18
  19def qualify_columns(
  20    expression: exp.Expression,
  21    schema: t.Dict | Schema,
  22    expand_alias_refs: bool = True,
  23    expand_stars: bool = True,
  24    infer_schema: t.Optional[bool] = None,
  25    allow_partial_qualification: bool = False,
  26    dialect: DialectType = None,
  27) -> exp.Expression:
  28    """
  29    Rewrite sqlglot AST to have fully qualified columns.
  30
  31    Example:
  32        >>> import sqlglot
  33        >>> schema = {"tbl": {"col": "INT"}}
  34        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
  35        >>> qualify_columns(expression, schema).sql()
  36        'SELECT tbl.col AS col FROM tbl'
  37
  38    Args:
  39        expression: Expression to qualify.
  40        schema: Database schema.
  41        expand_alias_refs: Whether to expand references to aliases.
  42        expand_stars: Whether to expand star queries. This is a necessary step
  43            for most of the optimizer's rules to work; do not set to False unless you
  44            know what you're doing!
  45        infer_schema: Whether to infer the schema if missing.
  46        allow_partial_qualification: Whether to allow partial qualification.
  47
  48    Returns:
  49        The qualified expression.
  50
  51    Notes:
  52        - Currently only handles a single PIVOT or UNPIVOT operator
  53    """
  54    schema = ensure_schema(schema, dialect=dialect)
  55    annotator = TypeAnnotator(schema)
  56    infer_schema = schema.empty if infer_schema is None else infer_schema
  57    dialect = Dialect.get_or_raise(schema.dialect)
  58    pseudocolumns = dialect.PSEUDOCOLUMNS
  59    bigquery = dialect == "bigquery"
  60
  61    for scope in traverse_scope(expression):
  62        scope_expression = scope.expression
  63        is_select = isinstance(scope_expression, exp.Select)
  64
  65        if is_select and scope_expression.args.get("connect"):
  66            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
  67            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
  68            scope_expression.transform(
  69                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
  70                copy=False,
  71            )
  72            scope.clear_cache()
  73
  74        resolver = Resolver(scope, schema, infer_schema=infer_schema)
  75        _pop_table_column_aliases(scope.ctes)
  76        _pop_table_column_aliases(scope.derived_tables)
  77        using_column_tables = _expand_using(scope, resolver)
  78
  79        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
  80            _expand_alias_refs(
  81                scope,
  82                resolver,
  83                dialect,
  84                expand_only_groupby=bigquery,
  85            )
  86
  87        _convert_columns_to_dots(scope, resolver)
  88        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
  89
  90        if not schema.empty and expand_alias_refs:
  91            _expand_alias_refs(scope, resolver, dialect)
  92
  93        if is_select:
  94            if expand_stars:
  95                _expand_stars(
  96                    scope,
  97                    resolver,
  98                    using_column_tables,
  99                    pseudocolumns,
 100                    annotator,
 101                )
 102            qualify_outputs(scope)
 103
 104        _expand_group_by(scope, dialect)
 105
 106        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
 107        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
 108        _expand_order_by_and_distinct_on(scope, resolver)
 109
 110        if bigquery:
 111            annotator.annotate_scope(scope)
 112
 113    return expression
 114
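
A minimal sketch of the infer_schema behavior (it defaults to schema.empty): with an
empty schema and a single source, the lone table is assumed to own any unqualified
column. The output shown is what recent sqlglot versions produce; treat it as
illustrative:

    >>> import sqlglot
    >>> from sqlglot.optimizer.qualify_columns import qualify_columns
    >>> qualify_columns(sqlglot.parse_one("SELECT col FROM tbl"), schema={}).sql()
    'SELECT tbl.col AS col FROM tbl'
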
 115
 116def validate_qualify_columns(expression: E) -> E:
 117    """Raise an `OptimizeError` if any columns aren't qualified"""
 118    all_unqualified_columns = []
 119    for scope in traverse_scope(expression):
 120        if isinstance(scope.expression, exp.Select):
 121            unqualified_columns = scope.unqualified_columns
 122
 123            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 124                column = scope.external_columns[0]
 125                for_table = f" for table: '{column.table}'" if column.table else ""
 126                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 127
 128            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 129                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 130                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 131                # this list here to ensure those in the former category will be excluded.
 132                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 133                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 134
 135            all_unqualified_columns.extend(unqualified_columns)
 136
 137    if all_unqualified_columns:
 138        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
 139
 140    return expression
 141
 142
 143def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
 144    name_columns = [
 145        field.this
 146        for field in unpivot.fields
 147        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
 148    ]
 149    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
 150
 151    return itertools.chain(name_columns, value_columns)
 152
 153
 154def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 155    """
 156    Remove table column aliases.
 157
 158    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 159    """
 160    for derived_table in derived_tables:
 161        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
 162            continue
 163        table_alias = derived_table.args.get("alias")
 164        if table_alias:
 165            table_alias.args.pop("columns", None)
 166
 167
 168def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
 169    columns = {}
 170
 171    def _update_source_columns(source_name: str) -> None:
 172        for column_name in resolver.get_source_columns(source_name):
 173            if column_name not in columns:
 174                columns[column_name] = source_name
 175
 176    joins = list(scope.find_all(exp.Join))
 177    names = {join.alias_or_name for join in joins}
 178    ordered = [key for key in scope.selected_sources if key not in names]
 179
 180    if names and not ordered:
 181        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")
 182
 183    # Mapping of automatically joined column names to an ordered set of source names (dict).
 184    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
 185
 186    for source_name in ordered:
 187        _update_source_columns(source_name)
 188
 189    for i, join in enumerate(joins):
 190        source_table = ordered[-1]
 191        if source_table:
 192            _update_source_columns(source_table)
 193
 194        join_table = join.alias_or_name
 195        ordered.append(join_table)
 196
 197        using = join.args.get("using")
 198        if not using:
 199            continue
 200
 201        join_columns = resolver.get_source_columns(join_table)
 202        conditions = []
 203        using_identifier_count = len(using)
 204        is_semi_or_anti_join = join.is_semi_or_anti_join
 205
 206        for identifier in using:
 207            identifier = identifier.name
 208            table = columns.get(identifier)
 209
 210            if not table or identifier not in join_columns:
 211                if (columns and "*" not in columns) and join_columns:
 212                    raise OptimizeError(f"Cannot automatically join: {identifier}")
 213
 214            table = table or source_table
 215
 216            if i == 0 or using_identifier_count == 1:
 217                lhs: exp.Expression = exp.column(identifier, table=table)
 218            else:
 219                coalesce_columns = [
 220                    exp.column(identifier, table=t)
 221                    for t in ordered[:-1]
 222                    if identifier in resolver.get_source_columns(t)
 223                ]
 224                if len(coalesce_columns) > 1:
 225                    lhs = exp.func("coalesce", *coalesce_columns)
 226                else:
 227                    lhs = exp.column(identifier, table=table)
 228
 229            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))
 230
 231            # Set all values in the dict to None, because we only care about the key ordering
 232            tables = column_tables.setdefault(identifier, {})
 233
 234            # Do not update the dict if this was a SEMI/ANTI join in
 235            # order to avoid generating COALESCE columns for this join pair
 236            if not is_semi_or_anti_join:
 237                if table not in tables:
 238                    tables[table] = None
 239                if join_table not in tables:
 240                    tables[join_table] = None
 241
 242        join.args.pop("using")
 243        join.set("on", exp.and_(*conditions, copy=False))
 244
 245    if column_tables:
 246        for column in scope.columns:
 247            if not column.table and column.name in column_tables:
 248                tables = column_tables[column.name]
 249                coalesce_args = [exp.column(column.name, table=table) for table in tables]
 250                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)
 251
 252                if isinstance(column.parent, exp.Select):
 253                    # Ensure the USING column keeps its name if it's projected
 254                    replacement = alias(replacement, alias=column.name, copy=False)
 255                elif isinstance(column.parent, exp.Struct):
 256                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
 257                    replacement = exp.PropertyEQ(
 258                        this=exp.to_identifier(column.name), expression=replacement
 259                    )
 260
 261                scope.replace(column, replacement)
 262
 263    return column_tables
 264
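
The effect of this rewrite, observed through the public entry point (a sketch with a
hypothetical schema): the USING clause is replaced by an equivalent ON condition and
each projection is qualified against the correct side of the join:

    >>> import sqlglot
    >>> from sqlglot.optimizer.qualify_columns import qualify_columns
    >>> schema = {"x": {"id": "INT", "a": "INT"}, "y": {"id": "INT", "b": "INT"}}
    >>> qualify_columns(sqlglot.parse_one("SELECT a, b FROM x JOIN y USING (id)"), schema).sql()
    'SELECT x.a AS a, y.b AS b FROM x JOIN y ON x.id = y.id'
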
 265
 266def _expand_alias_refs(
 267    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
 268) -> None:
 269    """
 270    Expand references to aliases.
 271    Example:
 272        SELECT y.foo AS bar, bar * 2 AS baz FROM y
 273     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
 274    """
 275    expression = scope.expression
 276
 277    if not isinstance(expression, exp.Select) or dialect == "oracle":
 278        return
 279
 280    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
 281    projections = {s.alias_or_name for s in expression.selects}
 282    is_bigquery = dialect == "bigquery"
 283
 284    def replace_columns(
 285        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
 286    ) -> None:
 287        is_group_by = isinstance(node, exp.Group)
 288        is_having = isinstance(node, exp.Having)
 289        if not node or (expand_only_groupby and not is_group_by):
 290            return
 291
 292        for column in walk_in_scope(node, prune=lambda node: node.is_star):
 293            if not isinstance(column, exp.Column):
 294                continue
 295
 296            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
 297            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
 298            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
 299            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
 300            if expand_only_groupby and is_group_by and column.parent is not node:
 301                continue
 302
 303            skip_replace = False
 304            table = resolver.get_table(column.name) if resolve_table and not column.table else None
 305            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
 306
 307            if alias_expr:
 308                skip_replace = bool(
 309                    alias_expr.find(exp.AggFunc)
 310                    and column.find_ancestor(exp.AggFunc)
 311                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
 312                )
 313
  314                # BigQuery's HAVING clause gets confused if an alias matches a source.
 315                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
 316                # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x" reference with the projection MAX(x.b)
  317                # i.e. HAVING MAX(MAX(x.b).b), resulting in the error: "Aggregations of aggregations are not allowed"
 318                if is_having and is_bigquery:
 319                    skip_replace = skip_replace or any(
 320                        node.parts[0].name in projections
 321                        for node in alias_expr.find_all(exp.Column)
 322                    )
 323            elif is_bigquery and (is_group_by or is_having):
 324                column_table = table.name if table else column.table
 325                if column_table in projections:
 326                    # BigQuery's GROUP BY and HAVING clauses get confused if the column name
 327                    # matches a source name and a projection. For instance:
 328                    # SELECT id, ARRAY_AGG(col) AS custom_fields FROM custom_fields GROUP BY id HAVING id >= 1
 329                    # We should not qualify "id" with "custom_fields" in either clause, since the aggregation shadows the actual table
 330                    # and we'd get the error: "Column custom_fields contains an aggregation function, which is not allowed in GROUP BY clause"
 331                    column.replace(exp.to_identifier(column.name))
 332                    return
 333
 334            if table and (not alias_expr or skip_replace):
 335                column.set("table", table)
 336            elif not column.table and alias_expr and not skip_replace:
 337                if (isinstance(alias_expr, exp.Literal) or alias_expr.is_number) and (
 338                    literal_index or resolve_table
 339                ):
 340                    if literal_index:
 341                        column.replace(exp.Literal.number(i))
 342                else:
 343                    column = column.replace(exp.paren(alias_expr))
 344                    simplified = simplify_parens(column, dialect)
 345                    if simplified is not column:
 346                        column.replace(simplified)
 347
 348    for i, projection in enumerate(expression.selects):
 349        replace_columns(projection)
 350        if isinstance(projection, exp.Alias):
 351            alias_to_expression[projection.alias] = (projection.this, i + 1)
 352
 353    parent_scope = scope
 354    while parent_scope.is_union:
 355        parent_scope = parent_scope.parent
 356
 357    # We shouldn't expand aliases if they match the recursive CTE's columns
 358    if parent_scope.is_cte:
 359        cte = parent_scope.expression.parent
 360        if cte.find_ancestor(exp.With).recursive:
 361            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
 362                alias_to_expression.pop(recursive_cte_column.output_name, None)
 363
 364    replace_columns(expression.args.get("where"))
 365    replace_columns(expression.args.get("group"), literal_index=True)
 366    replace_columns(expression.args.get("having"), resolve_table=True)
 367    replace_columns(expression.args.get("qualify"), resolve_table=True)
 368
 369    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
 370    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
 371    if dialect == "snowflake":
 372        for join in expression.args.get("joins") or []:
 373            replace_columns(join)
 374
 375    scope.clear_cache()
 376
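
The docstring's transformation end to end, as a sketch with a hypothetical schema:

    >>> import sqlglot
    >>> from sqlglot.optimizer.qualify_columns import qualify_columns
    >>> sql = "SELECT y.foo AS bar, bar * 2 AS baz FROM y"
    >>> qualify_columns(sqlglot.parse_one(sql), schema={"y": {"foo": "INT"}}).sql()
    'SELECT y.foo AS bar, y.foo * 2 AS baz FROM y'
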
 377
 378def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
 379    expression = scope.expression
 380    group = expression.args.get("group")
 381    if not group:
 382        return
 383
 384    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
 385    expression.set("group", group)
 386
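
A sketch of the effect (hypothetical schema): positional GROUP BY references are
resolved to the underlying projection expressions:

    >>> import sqlglot
    >>> from sqlglot.optimizer.qualify_columns import qualify_columns
    >>> sql = "SELECT a, SUM(b) AS s FROM t GROUP BY 1"
    >>> qualify_columns(sqlglot.parse_one(sql), schema={"t": {"a": "INT", "b": "INT"}}).sql()
    'SELECT t.a AS a, SUM(t.b) AS s FROM t GROUP BY t.a'
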
 387
 388def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
 389    for modifier_key in ("order", "distinct"):
 390        modifier = scope.expression.args.get(modifier_key)
 391        if isinstance(modifier, exp.Distinct):
 392            modifier = modifier.args.get("on")
 393
 394        if not isinstance(modifier, exp.Expression):
 395            continue
 396
 397        modifier_expressions = modifier.expressions
 398        if modifier_key == "order":
 399            modifier_expressions = [ordered.this for ordered in modifier_expressions]
 400
 401        for original, expanded in zip(
 402            modifier_expressions,
 403            _expand_positional_references(
 404                scope, modifier_expressions, resolver.schema.dialect, alias=True
 405            ),
 406        ):
 407            for agg in original.find_all(exp.AggFunc):
 408                for col in agg.find_all(exp.Column):
 409                    if not col.table:
 410                        col.set("table", resolver.get_table(col.name))
 411
 412            original.replace(expanded)
 413
 414        if scope.expression.args.get("group"):
 415            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
 416
 417            for expression in modifier_expressions:
 418                expression.replace(
 419                    exp.to_identifier(_select_by_pos(scope, expression).alias)
 420                    if expression.is_int
 421                    else selects.get(expression, expression)
 422                )
 423
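
Because positional references are expanded here with alias=True, an ORDER BY ordinal
resolves to the projection's alias rather than the underlying expression. A sketch
(exact output may vary slightly by version):

    >>> import sqlglot
    >>> from sqlglot.optimizer.qualify_columns import qualify_columns
    >>> sql = "SELECT a AS x FROM t ORDER BY 1"
    >>> qualify_columns(sqlglot.parse_one(sql), schema={"t": {"a": "INT"}}).sql()
    'SELECT t.a AS x FROM t ORDER BY x'
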
 424
 425def _expand_positional_references(
 426    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
 427) -> t.List[exp.Expression]:
 428    new_nodes: t.List[exp.Expression] = []
 429    ambiguous_projections = None
 430
 431    for node in expressions:
 432        if node.is_int:
 433            select = _select_by_pos(scope, t.cast(exp.Literal, node))
 434
 435            if alias:
 436                new_nodes.append(exp.column(select.args["alias"].copy()))
 437            else:
 438                select = select.this
 439
 440                if dialect == "bigquery":
 441                    if ambiguous_projections is None:
 442                        # When a projection name is also a source name and it is referenced in the
 443                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
 444                        ambiguous_projections = {
 445                            s.alias_or_name
 446                            for s in scope.expression.selects
 447                            if s.alias_or_name in scope.selected_sources
 448                        }
 449
 450                    ambiguous = any(
 451                        column.parts[0].name in ambiguous_projections
 452                        for column in select.find_all(exp.Column)
 453                    )
 454                else:
 455                    ambiguous = False
 456
 457                if (
 458                    isinstance(select, exp.CONSTANTS)
 459                    or select.is_number
 460                    or select.find(exp.Explode, exp.Unnest)
 461                    or ambiguous
 462                ):
 463                    new_nodes.append(node)
 464                else:
 465                    new_nodes.append(select.copy())
 466        else:
 467            new_nodes.append(node)
 468
 469    return new_nodes
 470
 471
 472def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
 473    try:
 474        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
 475    except IndexError:
 476        raise OptimizeError(f"Unknown output column: {node.name}")
 477
 478
 479def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
 480    """
 481    Converts `Column` instances that represent struct field lookup into chained `Dots`.
 482
 483    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
 484    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
 485    """
 486    converted = False
 487    for column in itertools.chain(scope.columns, scope.stars):
 488        if isinstance(column, exp.Dot):
 489            continue
 490
 491        column_table: t.Optional[str | exp.Identifier] = column.table
 492        if (
 493            column_table
 494            and column_table not in scope.sources
 495            and (
 496                not scope.parent
 497                or column_table not in scope.parent.sources
 498                or not scope.is_correlated_subquery
 499            )
 500        ):
 501            root, *parts = column.parts
 502
 503            if root.name in scope.sources:
 504                # The struct is already qualified, but we still need to change the AST
 505                column_table = root
 506                root, *parts = parts
 507            else:
 508                column_table = resolver.get_table(root.name)
 509
 510            if column_table:
 511                converted = True
 512                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
 513
 514    if converted:
 515        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
 516        # a `for column in scope.columns` iteration, even though they shouldn't be
 517        scope.clear_cache()
 518
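
A sketch of this rewrite via the public API (hypothetical BigQuery schema): "s.f"
initially parses as column "f" of table "s"; since "s" is not a source, it is
resolved as a struct column of "t" and rebuilt as a Dot chain:

    >>> import sqlglot
    >>> from sqlglot.optimizer.qualify_columns import qualify_columns
    >>> e = sqlglot.parse_one("SELECT s.f FROM t", dialect="bigquery")
    >>> qualify_columns(e, {"t": {"s": "STRUCT<f INT64>"}}, dialect="bigquery").sql(dialect="bigquery")
    'SELECT t.s.f AS f FROM t'
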
 519
 520def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
 521    """Disambiguate columns, ensuring each column specifies a source"""
 522    for column in scope.columns:
 523        column_table = column.table
 524        column_name = column.name
 525
 526        if column_table and column_table in scope.sources:
 527            source_columns = resolver.get_source_columns(column_table)
 528            if (
 529                not allow_partial_qualification
 530                and source_columns
 531                and column_name not in source_columns
 532                and "*" not in source_columns
 533            ):
 534                raise OptimizeError(f"Unknown column: {column_name}")
 535
 536        if not column_table:
 537            if scope.pivots and not column.find_ancestor(exp.Pivot):
 538                # If the column is under the Pivot expression, we need to qualify it
 539                # using the name of the pivoted source instead of the pivot's alias
 540                column.set("table", exp.to_identifier(scope.pivots[0].alias))
 541                continue
 542
  543            # column_table can be '' because BigQuery's UNNEST has no table alias
 544            column_table = resolver.get_table(column_name)
 545            if column_table:
 546                column.set("table", column_table)
 547            elif (
 548                resolver.schema.dialect == "bigquery"
 549                and len(column.parts) == 1
 550                and column_name in scope.selected_sources
 551            ):
 552                # BigQuery allows tables to be referenced as columns, treating them as structs
 553                scope.replace(column, exp.TableColumn(this=column.this))
 554
 555    for pivot in scope.pivots:
 556        for column in pivot.find_all(exp.Column):
 557            if not column.table and column.name in resolver.all_columns:
 558                column_table = resolver.get_table(column.name)
 559                if column_table:
 560                    column.set("table", column_table)
 561
 562
 563def _expand_struct_stars_bigquery(
 564    expression: exp.Dot,
 565) -> t.List[exp.Alias]:
 566    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
 567
 568    dot_column = expression.find(exp.Column)
 569    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
 570        return []
 571
  572    # All nested struct values are ColumnDefs, so normalize the first exp.Column into one
 573    dot_column = dot_column.copy()
 574    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)
 575
  576    # The first part is the table name and the last part is the star, so they can be dropped
 577    dot_parts = expression.parts[1:-1]
 578
  579    # If we're expanding a nested struct, e.g. t.c.f1.f2.*, find the last struct (f2 in this case)
 580    for part in dot_parts[1:]:
 581        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
 582            # Unable to expand star unless all fields are named
 583            if not isinstance(field.this, exp.Identifier):
 584                return []
 585
 586            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
 587                starting_struct = field
 588                break
 589        else:
 590            # There is no matching field in the struct
 591            return []
 592
 593    taken_names = set()
 594    new_selections = []
 595
 596    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
 597        name = field.name
 598
 599        # Ambiguous or anonymous fields can't be expanded
 600        if name in taken_names or not isinstance(field.this, exp.Identifier):
 601            return []
 602
 603        taken_names.add(name)
 604
 605        this = field.this.copy()
 606        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
 607        new_column = exp.column(
 608            t.cast(exp.Identifier, root),
 609            table=dot_column.args.get("table"),
 610            fields=t.cast(t.List[exp.Identifier], parts),
 611        )
 612        new_selections.append(alias(new_column, this, copy=False))
 613
 614    return new_selections
 615
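
A sketch of the expansion through qualify_columns (hypothetical schema; output shown
is from a recent sqlglot version):

    >>> import sqlglot
    >>> from sqlglot.optimizer.qualify_columns import qualify_columns
    >>> e = sqlglot.parse_one("SELECT t.s.* FROM t", dialect="bigquery")
    >>> qualify_columns(e, {"t": {"s": "STRUCT<a INT64, b INT64>"}}, dialect="bigquery").sql(dialect="bigquery")
    'SELECT t.s.a AS a, t.s.b AS b FROM t'
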
 616
 617def _expand_struct_stars_risingwave(expression: exp.Dot) -> t.List[exp.Alias]:
 618    """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column"""
 619
  620    # if it is not a (<sub_exp>).* pattern, we can't expand
 621    if not isinstance(expression.this, exp.Paren):
 622        return []
 623
 624    # find column definition to get data-type
 625    dot_column = expression.find(exp.Column)
 626    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
 627        return []
 628
 629    parent = dot_column.parent
 630    starting_struct = dot_column.type
 631
  632    # walk up the AST and down into the struct definition in sync
 633    while parent is not None:
 634        if isinstance(parent, exp.Paren):
 635            parent = parent.parent
 636            continue
 637
 638        # if parent is not a dot, then something is wrong
 639        if not isinstance(parent, exp.Dot):
 640            return []
 641
  642        # if the rhs of the dot is a star, we are done
 643        rhs = parent.right
 644        if isinstance(rhs, exp.Star):
 645            break
 646
  648        # if it is not an identifier, then something is wrong
 648        if not isinstance(rhs, exp.Identifier):
 649            return []
 650
 651        # Check if current rhs identifier is in struct
 652        matched = False
 653        for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
 654            if struct_field_def.name == rhs.name:
 655                matched = True
 656                starting_struct = struct_field_def.kind  # update struct
 657                break
 658
 659        if not matched:
 660            return []
 661
 662        parent = parent.parent
 663
 664    # build new aliases to expand star
 665    new_selections = []
 666
  667    # fetch the outermost parentheses for the new aliases
 668    outer_paren = expression.this
 669
 670    for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
 671        new_identifier = struct_field_def.this.copy()
 672        new_dot = exp.Dot.build([outer_paren.copy(), new_identifier])
 673        new_alias = alias(new_dot, new_identifier, copy=False)
 674        new_selections.append(new_alias)
 675
 676    return new_selections
 677
 678
 679def _expand_stars(
 680    scope: Scope,
 681    resolver: Resolver,
 682    using_column_tables: t.Dict[str, t.Any],
 683    pseudocolumns: t.Set[str],
 684    annotator: TypeAnnotator,
 685) -> None:
 686    """Expand stars to lists of column selections"""
 687
 688    new_selections: t.List[exp.Expression] = []
 689    except_columns: t.Dict[int, t.Set[str]] = {}
 690    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
 691    rename_columns: t.Dict[int, t.Dict[str, str]] = {}
 692
 693    coalesced_columns = set()
 694    dialect = resolver.schema.dialect
 695
 696    pivot_output_columns = None
 697    pivot_exclude_columns: t.Set[str] = set()
 698
 699    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
 700    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
 701        if pivot.unpivot:
 702            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
 703
 704            for field in pivot.fields:
 705                if isinstance(field, exp.In):
 706                    pivot_exclude_columns.update(
 707                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
 708                    )
 709
 710        else:
 711            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
 712
 713            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
 714            if not pivot_output_columns:
 715                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
 716
 717    is_bigquery = dialect == "bigquery"
 718    is_risingwave = dialect == "risingwave"
 719
 720    if (is_bigquery or is_risingwave) and any(isinstance(col, exp.Dot) for col in scope.stars):
 721        # Found struct expansion, annotate scope ahead of time
 722        annotator.annotate_scope(scope)
 723
 724    for expression in scope.expression.selects:
 725        tables = []
 726        if isinstance(expression, exp.Star):
 727            tables.extend(scope.selected_sources)
 728            _add_except_columns(expression, tables, except_columns)
 729            _add_replace_columns(expression, tables, replace_columns)
 730            _add_rename_columns(expression, tables, rename_columns)
 731        elif expression.is_star:
 732            if not isinstance(expression, exp.Dot):
 733                tables.append(expression.table)
 734                _add_except_columns(expression.this, tables, except_columns)
 735                _add_replace_columns(expression.this, tables, replace_columns)
 736                _add_rename_columns(expression.this, tables, rename_columns)
 737            elif is_bigquery:
 738                struct_fields = _expand_struct_stars_bigquery(expression)
 739                if struct_fields:
 740                    new_selections.extend(struct_fields)
 741                    continue
 742            elif is_risingwave:
 743                struct_fields = _expand_struct_stars_risingwave(expression)
 744                if struct_fields:
 745                    new_selections.extend(struct_fields)
 746                    continue
 747
 748        if not tables:
 749            new_selections.append(expression)
 750            continue
 751
 752        for table in tables:
 753            if table not in scope.sources:
 754                raise OptimizeError(f"Unknown table: {table}")
 755
 756            columns = resolver.get_source_columns(table, only_visible=True)
 757            columns = columns or scope.outer_columns
 758
 759            if pseudocolumns:
 760                columns = [name for name in columns if name.upper() not in pseudocolumns]
 761
 762            if not columns or "*" in columns:
 763                return
 764
 765            table_id = id(table)
 766            columns_to_exclude = except_columns.get(table_id) or set()
 767            renamed_columns = rename_columns.get(table_id, {})
 768            replaced_columns = replace_columns.get(table_id, {})
 769
 770            if pivot:
 771                if pivot_output_columns and pivot_exclude_columns:
 772                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
 773                    pivot_columns.extend(pivot_output_columns)
 774                else:
 775                    pivot_columns = pivot.alias_column_names
 776
 777                if pivot_columns:
 778                    new_selections.extend(
 779                        alias(exp.column(name, table=pivot.alias), name, copy=False)
 780                        for name in pivot_columns
 781                        if name not in columns_to_exclude
 782                    )
 783                    continue
 784
 785            for name in columns:
 786                if name in columns_to_exclude or name in coalesced_columns:
 787                    continue
 788                if name in using_column_tables and table in using_column_tables[name]:
 789                    coalesced_columns.add(name)
 790                    tables = using_column_tables[name]
 791                    coalesce_args = [exp.column(name, table=table) for table in tables]
 792
 793                    new_selections.append(
 794                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
 795                    )
 796                else:
 797                    alias_ = renamed_columns.get(name, name)
 798                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
 799                    new_selections.append(
 800                        alias(selection_expr, alias_, copy=False)
 801                        if alias_ != name
 802                        else selection_expr
 803                    )
 804
 805    # Ensures we don't overwrite the initial selections with an empty list
 806    if new_selections and isinstance(scope.expression, exp.Select):
 807        scope.expression.set("expressions", new_selections)
 808
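
Star expansion in its simplest form, as a sketch with a hypothetical schema
(EXCEPT/REPLACE/RENAME modifiers flow through the maps built above):

    >>> import sqlglot
    >>> from sqlglot.optimizer.qualify_columns import qualify_columns
    >>> qualify_columns(sqlglot.parse_one("SELECT * FROM t"), {"t": {"a": "INT", "b": "INT"}}).sql()
    'SELECT t.a AS a, t.b AS b FROM t'
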
 809
 810def _add_except_columns(
 811    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
 812) -> None:
 813    except_ = expression.args.get("except")
 814
 815    if not except_:
 816        return
 817
 818    columns = {e.name for e in except_}
 819
 820    for table in tables:
 821        except_columns[id(table)] = columns
 822
 823
 824def _add_rename_columns(
 825    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
 826) -> None:
 827    rename = expression.args.get("rename")
 828
 829    if not rename:
 830        return
 831
 832    columns = {e.this.name: e.alias for e in rename}
 833
 834    for table in tables:
 835        rename_columns[id(table)] = columns
 836
 837
 838def _add_replace_columns(
 839    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
 840) -> None:
 841    replace = expression.args.get("replace")
 842
 843    if not replace:
 844        return
 845
 846    columns = {e.alias: e for e in replace}
 847
 848    for table in tables:
 849        replace_columns[id(table)] = columns
 850
 851
 852def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
 853    """Ensure all output columns are aliased"""
 854    if isinstance(scope_or_expression, exp.Expression):
 855        scope = build_scope(scope_or_expression)
 856        if not isinstance(scope, Scope):
 857            return
 858    else:
 859        scope = scope_or_expression
 860
 861    new_selections = []
 862    for i, (selection, aliased_column) in enumerate(
 863        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
 864    ):
 865        if selection is None or isinstance(selection, exp.QueryTransform):
 866            break
 867
 868        if isinstance(selection, exp.Subquery):
 869            if not selection.output_name:
 870                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
 871        elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star:
 872            selection = alias(
 873                selection,
 874                alias=selection.output_name or f"_col_{i}",
 875                copy=False,
 876            )
 877        if aliased_column:
 878            selection.set("alias", exp.to_identifier(aliased_column))
 879
 880        new_selections.append(selection)
 881
 882    if new_selections and isinstance(scope.expression, exp.Select):
 883        scope.expression.set("expressions", new_selections)
 884
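
A sketch of output aliasing: projections with a derivable name keep it, while
anonymous ones get positional _col_{i} names:

    >>> import sqlglot
    >>> from sqlglot.optimizer.qualify_columns import qualify_outputs
    >>> e = sqlglot.parse_one("SELECT x.a, 1 + 1 FROM x")
    >>> qualify_outputs(e)
    >>> e.sql()
    'SELECT x.a AS a, 1 + 1 AS _col_1 FROM x'
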
 885
 886def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
 887    """Makes sure all identifiers that need to be quoted are quoted."""
 888    return expression.transform(
 889        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
 890    )  # type: ignore
 891
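
For example (a sketch; quoting rules are dialect-specific):

    >>> import sqlglot
    >>> from sqlglot.optimizer.qualify_columns import quote_identifiers
    >>> quote_identifiers(sqlglot.parse_one("SELECT a FROM t"), dialect="snowflake").sql(dialect="snowflake")
    'SELECT "a" FROM "t"'
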
 892
 893def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
 894    """
  895    Pushes down the CTE alias columns into the projection.
 896
 897    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
 898
 899    Example:
 900        >>> import sqlglot
 901        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
 902        >>> pushdown_cte_alias_columns(expression).sql()
 903        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
 904
 905    Args:
 906        expression: Expression to pushdown.
 907
 908    Returns:
 909        The expression with the CTE aliases pushed down into the projection.
 910    """
 911    for cte in expression.find_all(exp.CTE):
 912        if cte.alias_column_names:
 913            new_expressions = []
 914            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
 915                if isinstance(projection, exp.Alias):
 916                    projection.set("alias", _alias)
 917                else:
 918                    projection = alias(projection, alias=_alias)
 919                new_expressions.append(projection)
 920            cte.this.set("expressions", new_expressions)
 921
 922    return expression
 923
 924
 925class Resolver:
 926    """
 927    Helper for resolving columns.
 928
 929    This is a class so we can lazily load some things and easily share them across functions.
 930    """
 931
 932    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
 933        self.scope = scope
 934        self.schema = schema
 935        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
 936        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
 937        self._all_columns: t.Optional[t.Set[str]] = None
 938        self._infer_schema = infer_schema
 939        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
 940
 941    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
 942        """
 943        Get the table for a column name.
 944
 945        Args:
 946            column_name: The column name to find the table for.
 947        Returns:
 948            The table name if it can be found/inferred.
 949        """
 950        if self._unambiguous_columns is None:
 951            self._unambiguous_columns = self._get_unambiguous_columns(
 952                self._get_all_source_columns()
 953            )
 954
 955        table_name = self._unambiguous_columns.get(column_name)
 956
 957        if not table_name and self._infer_schema:
 958            sources_without_schema = tuple(
 959                source
 960                for source, columns in self._get_all_source_columns().items()
 961                if not columns or "*" in columns
 962            )
 963            if len(sources_without_schema) == 1:
 964                table_name = sources_without_schema[0]
 965
 966        if table_name not in self.scope.selected_sources:
 967            return exp.to_identifier(table_name)
 968
 969        node, _ = self.scope.selected_sources.get(table_name)
 970
 971        if isinstance(node, exp.Query):
 972            while node and node.alias != table_name:
 973                node = node.parent
 974
 975        node_alias = node.args.get("alias")
 976        if node_alias:
 977            return exp.to_identifier(node_alias.this)
 978
 979        return exp.to_identifier(table_name)
 980
 981    @property
 982    def all_columns(self) -> t.Set[str]:
 983        """All available columns of all sources in this scope"""
 984        if self._all_columns is None:
 985            self._all_columns = {
 986                column for columns in self._get_all_source_columns().values() for column in columns
 987            }
 988        return self._all_columns
 989
 990    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
 991        if isinstance(expression, exp.Select):
 992            return expression.named_selects
 993        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
 994            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
 995            return self.get_source_columns_from_set_op(expression.this)
 996        if not isinstance(expression, exp.SetOperation):
 997            raise OptimizeError(f"Unknown set operation: {expression}")
 998
 999        set_op = expression
1000
 1001        # BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
1002        on_column_list = set_op.args.get("on")
1003
1004        if on_column_list:
1005            # The resulting columns are the columns in the ON clause:
1006            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
1007            columns = [col.name for col in on_column_list]
1008        elif set_op.side or set_op.kind:
1009            side = set_op.side
1010            kind = set_op.kind
1011
1012            # Visit the children UNIONs (if any) in a post-order traversal
1013            left = self.get_source_columns_from_set_op(set_op.left)
1014            right = self.get_source_columns_from_set_op(set_op.right)
1015
1016            # We use dict.fromkeys to deduplicate keys and maintain insertion order
1017            if side == "LEFT":
1018                columns = left
1019            elif side == "FULL":
1020                columns = list(dict.fromkeys(left + right))
1021            elif kind == "INNER":
1022                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
1023        else:
1024            columns = set_op.named_selects
1025
1026        return columns
1027
1028    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
1029        """Resolve the source columns for a given source `name`."""
1030        cache_key = (name, only_visible)
1031        if cache_key not in self._get_source_columns_cache:
1032            if name not in self.scope.sources:
1033                raise OptimizeError(f"Unknown table: {name}")
1034
1035            source = self.scope.sources[name]
1036
1037            if isinstance(source, exp.Table):
1038                columns = self.schema.column_names(source, only_visible)
1039            elif isinstance(source, Scope) and isinstance(
1040                source.expression, (exp.Values, exp.Unnest)
1041            ):
1042                columns = source.expression.named_selects
1043
 1044                # In BigQuery, UNNEST structs are automatically scoped as tables, so you can
 1045                # directly select a struct field in a query.
 1046                # This handles the case where the unnest is statically defined.
1047                if self.schema.dialect == "bigquery":
1048                    if source.expression.is_type(exp.DataType.Type.STRUCT):
1049                        for k in source.expression.type.expressions:  # type: ignore
1050                            columns.append(k.name)
1051            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
1052                columns = self.get_source_columns_from_set_op(source.expression)
1053
1054            else:
1055                select = seq_get(source.expression.selects, 0)
1056
1057                if isinstance(select, exp.QueryTransform):
1058                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
1059                    schema = select.args.get("schema")
1060                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
1061                else:
1062                    columns = source.expression.named_selects
1063
1064            node, _ = self.scope.selected_sources.get(name) or (None, None)
1065            if isinstance(node, Scope):
1066                column_aliases = node.expression.alias_column_names
1067            elif isinstance(node, exp.Expression):
1068                column_aliases = node.alias_column_names
1069            else:
1070                column_aliases = []
1071
1072            if column_aliases:
1073                # If the source's columns are aliased, their aliases shadow the corresponding column names.
1074                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
1075                columns = [
1076                    alias or name
1077                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
1078                ]
1079
1080            self._get_source_columns_cache[cache_key] = columns
1081
1082        return self._get_source_columns_cache[cache_key]
1083
1084    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
1085        if self._source_columns is None:
1086            self._source_columns = {
1087                source_name: self.get_source_columns(source_name)
1088                for source_name, source in itertools.chain(
1089                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
1090                )
1091            }
1092        return self._source_columns
1093
1094    def _get_unambiguous_columns(
1095        self, source_columns: t.Dict[str, t.Sequence[str]]
1096    ) -> t.Mapping[str, str]:
1097        """
1098        Find all the unambiguous columns in sources.
1099
1100        Args:
1101            source_columns: Mapping of names to source columns.
1102
1103        Returns:
1104            Mapping of column name to source name.
1105        """
1106        if not source_columns:
1107            return {}
1108
1109        source_columns_pairs = list(source_columns.items())
1110
1111        first_table, first_columns = source_columns_pairs[0]
1112
1113        if len(source_columns_pairs) == 1:
1114            # Performance optimization - avoid copying first_columns if there is only one table.
1115            return SingleValuedMapping(first_columns, first_table)
1116
1117        unambiguous_columns = {col: first_table for col in first_columns}
1118        all_columns = set(unambiguous_columns)
1119
1120        for table, columns in source_columns_pairs[1:]:
1121            unique = set(columns)
1122            ambiguous = all_columns.intersection(unique)
1123            all_columns.update(columns)
1124
1125            for column in ambiguous:
1126                unambiguous_columns.pop(column, None)
1127            for column in unique.difference(ambiguous):
1128                unambiguous_columns[column] = table
1129
1130        return unambiguous_columns
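
A minimal sketch of using the Resolver directly (hypothetical schema; the class is
primarily an internal helper, so treat the exact return shapes as illustrative):

    >>> import sqlglot
    >>> from sqlglot.optimizer.scope import build_scope
    >>> from sqlglot.optimizer.qualify_columns import Resolver
    >>> from sqlglot.schema import ensure_schema
    >>> schema = ensure_schema({"x": {"id": "INT", "a": "INT"}, "y": {"id": "INT", "b": "INT"}})
    >>> scope = build_scope(sqlglot.parse_one("SELECT a FROM x JOIN y ON x.id = y.id"))
    >>> resolver = Resolver(scope, schema)
    >>> list(resolver.get_source_columns("x"))
    ['id', 'a']
    >>> sorted(resolver.all_columns)
    ['a', 'b', 'id']
    >>> resolver.get_table("a").sql()  # 'a' only exists in x
    'x'
    >>> resolver.get_table("id") is None  # ambiguous: present in both sources
    True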
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
 20def qualify_columns(
 21    expression: exp.Expression,
 22    schema: t.Dict | Schema,
 23    expand_alias_refs: bool = True,
 24    expand_stars: bool = True,
 25    infer_schema: t.Optional[bool] = None,
 26    allow_partial_qualification: bool = False,
 27    dialect: DialectType = None,
 28) -> exp.Expression:
 29    """
 30    Rewrite sqlglot AST to have fully qualified columns.
 31
 32    Example:
 33        >>> import sqlglot
 34        >>> schema = {"tbl": {"col": "INT"}}
 35        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 36        >>> qualify_columns(expression, schema).sql()
 37        'SELECT tbl.col AS col FROM tbl'
 38
 39    Args:
 40        expression: Expression to qualify.
 41        schema: Database schema.
 42        expand_alias_refs: Whether to expand references to aliases.
 43        expand_stars: Whether to expand star queries. This is a necessary step
 44            for most of the optimizer's rules to work; do not set to False unless you
 45            know what you're doing!
 46        infer_schema: Whether to infer the schema if missing.
 47        allow_partial_qualification: Whether to allow partial qualification.
 48
 49    Returns:
 50        The qualified expression.
 51
 52    Notes:
 53        - Currently only handles a single PIVOT or UNPIVOT operator
 54    """
 55    schema = ensure_schema(schema, dialect=dialect)
 56    annotator = TypeAnnotator(schema)
 57    infer_schema = schema.empty if infer_schema is None else infer_schema
 58    dialect = Dialect.get_or_raise(schema.dialect)
 59    pseudocolumns = dialect.PSEUDOCOLUMNS
 60    bigquery = dialect == "bigquery"
 61
 62    for scope in traverse_scope(expression):
 63        scope_expression = scope.expression
 64        is_select = isinstance(scope_expression, exp.Select)
 65
 66        if is_select and scope_expression.args.get("connect"):
 67            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
 68            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
 69            scope_expression.transform(
 70                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
 71                copy=False,
 72            )
 73            scope.clear_cache()
 74
 75        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 76        _pop_table_column_aliases(scope.ctes)
 77        _pop_table_column_aliases(scope.derived_tables)
 78        using_column_tables = _expand_using(scope, resolver)
 79
 80        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
 81            _expand_alias_refs(
 82                scope,
 83                resolver,
 84                dialect,
 85                expand_only_groupby=bigquery,
 86            )
 87
 88        _convert_columns_to_dots(scope, resolver)
 89        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
 90
 91        if not schema.empty and expand_alias_refs:
 92            _expand_alias_refs(scope, resolver, dialect)
 93
 94        if is_select:
 95            if expand_stars:
 96                _expand_stars(
 97                    scope,
 98                    resolver,
 99                    using_column_tables,
100                    pseudocolumns,
101                    annotator,
102                )
103            qualify_outputs(scope)
104
105        _expand_group_by(scope, dialect)
106
107        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
108        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
109        _expand_order_by_and_distinct_on(scope, resolver)
110
111        if bigquery:
112            annotator.annotate_scope(scope)
113
114    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
117def validate_qualify_columns(expression: E) -> E:
118    """Raise an `OptimizeError` if any columns aren't qualified"""
119    all_unqualified_columns = []
120    for scope in traverse_scope(expression):
121        if isinstance(scope.expression, exp.Select):
122            unqualified_columns = scope.unqualified_columns
123
124            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
125                column = scope.external_columns[0]
126                for_table = f" for table: '{column.table}'" if column.table else ""
127                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
128
129            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
130                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
131                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
132                # this list here to ensure those in the former category will be excluded.
133                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
134                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
135
136            all_unqualified_columns.extend(unqualified_columns)
137
138    if all_unqualified_columns:
139        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
140
141    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
853def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
854    """Ensure all output columns are aliased"""
855    if isinstance(scope_or_expression, exp.Expression):
856        scope = build_scope(scope_or_expression)
857        if not isinstance(scope, Scope):
858            return
859    else:
860        scope = scope_or_expression
861
862    new_selections = []
863    for i, (selection, aliased_column) in enumerate(
864        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
865    ):
866        if selection is None or isinstance(selection, exp.QueryTransform):
867            break
868
869        if isinstance(selection, exp.Subquery):
870            if not selection.output_name:
871                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
872        elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star:
873            selection = alias(
874                selection,
875                alias=selection.output_name or f"_col_{i}",
876                copy=False,
877            )
878        if aliased_column:
879            selection.set("alias", exp.to_identifier(aliased_column))
880
881        new_selections.append(selection)
882
883    if new_selections and isinstance(scope.expression, exp.Select):
884        scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased
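
For example (a sketch with illustrative names): an unaliased expression receives a positional _col_<i> alias, while a plain column is aliased to its own name:

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import qualify_outputs
>>> expression = sqlglot.parse_one("SELECT x + 1, y FROM t")
>>> qualify_outputs(expression)
>>> expression.sql()
'SELECT x + 1 AS _col_0, y AS y FROM t'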

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
887def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
888    """Makes sure all identifiers that need to be quoted are quoted."""
889    return expression.transform(
890        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
891    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.
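
A quick sketch with the default dialect (which quotes with double quotes); since identify defaults to True, every identifier is quoted:

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import quote_identifiers
>>> quote_identifiers(sqlglot.parse_one("SELECT a FROM tbl")).sql()
'SELECT "a" FROM "tbl"'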

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
894def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
895    """
896    Pushes down the CTE alias columns into the projection.
897
898    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
899
900    Example:
901        >>> import sqlglot
902        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
903        >>> pushdown_cte_alias_columns(expression).sql()
904        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
905
906    Args:
907        expression: Expression to pushdown.
908
909    Returns:
910        The expression with the CTE aliases pushed down into the projection.
911    """
912    for cte in expression.find_all(exp.CTE):
913        if cte.alias_column_names:
914            new_expressions = []
915            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
916                if isinstance(projection, exp.Alias):
917                    projection.set("alias", _alias)
918                else:
919                    projection = alias(projection, alias=_alias)
920                new_expressions.append(projection)
921            cte.this.set("expressions", new_expressions)
922
923    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:
  The expression with the CTE aliases pushed down into the projection.
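
To also exercise the branch where a projection is already an exp.Alias (a sketch with made-up names), the existing alias is replaced by the CTE column name rather than wrapped again:

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import pushdown_cte_alias_columns
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT a AS b FROM x) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT a AS c FROM x) SELECT c FROM y'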

class Resolver:
 926class Resolver:
 927    """
 928    Helper for resolving columns.
 929
 930    This is a class so we can lazily load some things and easily share them across functions.
 931    """
 932
 933    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
 934        self.scope = scope
 935        self.schema = schema
 936        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
 937        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
 938        self._all_columns: t.Optional[t.Set[str]] = None
 939        self._infer_schema = infer_schema
 940        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
 941
 942    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
 943        """
 944        Get the table for a column name.
 945
 946        Args:
 947            column_name: The column name to find the table for.
 948        Returns:
 949            The table name if it can be found/inferred.
 950        """
 951        if self._unambiguous_columns is None:
 952            self._unambiguous_columns = self._get_unambiguous_columns(
 953                self._get_all_source_columns()
 954            )
 955
 956        table_name = self._unambiguous_columns.get(column_name)
 957
 958        if not table_name and self._infer_schema:
 959            sources_without_schema = tuple(
 960                source
 961                for source, columns in self._get_all_source_columns().items()
 962                if not columns or "*" in columns
 963            )
 964            if len(sources_without_schema) == 1:
 965                table_name = sources_without_schema[0]
 966
 967        if table_name not in self.scope.selected_sources:
 968            return exp.to_identifier(table_name)
 969
 970        node, _ = self.scope.selected_sources.get(table_name)
 971
 972        if isinstance(node, exp.Query):
 973            while node and node.alias != table_name:
 974                node = node.parent
 975
 976        node_alias = node.args.get("alias")
 977        if node_alias:
 978            return exp.to_identifier(node_alias.this)
 979
 980        return exp.to_identifier(table_name)
 981
 982    @property
 983    def all_columns(self) -> t.Set[str]:
 984        """All available columns of all sources in this scope"""
 985        if self._all_columns is None:
 986            self._all_columns = {
 987                column for columns in self._get_all_source_columns().values() for column in columns
 988            }
 989        return self._all_columns
 990
 991    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
 992        if isinstance(expression, exp.Select):
 993            return expression.named_selects
 994        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
 995            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
 996            return self.get_source_columns_from_set_op(expression.this)
 997        if not isinstance(expression, exp.SetOperation):
 998            raise OptimizeError(f"Unknown set operation: {expression}")
 999
1000        set_op = expression
1001
1002        # BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
1003        on_column_list = set_op.args.get("on")
1004
1005        if on_column_list:
1006            # The resulting columns are the columns in the ON clause:
1007            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
1008            columns = [col.name for col in on_column_list]
1009        elif set_op.side or set_op.kind:
1010            side = set_op.side
1011            kind = set_op.kind
1012
1013            # Visit the children UNIONs (if any) in a post-order traversal
1014            left = self.get_source_columns_from_set_op(set_op.left)
1015            right = self.get_source_columns_from_set_op(set_op.right)
1016
1017            # We use dict.fromkeys to deduplicate keys and maintain insertion order
1018            if side == "LEFT":
1019                columns = left
1020            elif side == "FULL":
1021                columns = list(dict.fromkeys(left + right))
1022            elif kind == "INNER":
1023                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
1024        else:
1025            columns = set_op.named_selects
1026
1027        return columns
1028
1029    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
1030        """Resolve the source columns for a given source `name`."""
1031        cache_key = (name, only_visible)
1032        if cache_key not in self._get_source_columns_cache:
1033            if name not in self.scope.sources:
1034                raise OptimizeError(f"Unknown table: {name}")
1035
1036            source = self.scope.sources[name]
1037
1038            if isinstance(source, exp.Table):
1039                columns = self.schema.column_names(source, only_visible)
1040            elif isinstance(source, Scope) and isinstance(
1041                source.expression, (exp.Values, exp.Unnest)
1042            ):
1043                columns = source.expression.named_selects
1044
1045                # in bigquery, unnest structs are automatically scoped as tables, so you can
1046                # directly select a struct field in a query.
1047                # this handles the case where the unnest is statically defined.
1048                if self.schema.dialect == "bigquery":
1049                    if source.expression.is_type(exp.DataType.Type.STRUCT):
1050                        for k in source.expression.type.expressions:  # type: ignore
1051                            columns.append(k.name)
1052            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
1053                columns = self.get_source_columns_from_set_op(source.expression)
1054
1055            else:
1056                select = seq_get(source.expression.selects, 0)
1057
1058                if isinstance(select, exp.QueryTransform):
1059                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
1060                    schema = select.args.get("schema")
1061                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
1062                else:
1063                    columns = source.expression.named_selects
1064
1065            node, _ = self.scope.selected_sources.get(name) or (None, None)
1066            if isinstance(node, Scope):
1067                column_aliases = node.expression.alias_column_names
1068            elif isinstance(node, exp.Expression):
1069                column_aliases = node.alias_column_names
1070            else:
1071                column_aliases = []
1072
1073            if column_aliases:
1074                # If the source's columns are aliased, their aliases shadow the corresponding column names.
1075                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
1076                columns = [
1077                    alias or name
1078                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
1079                ]
1080
1081            self._get_source_columns_cache[cache_key] = columns
1082
1083        return self._get_source_columns_cache[cache_key]
1084
1085    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
1086        if self._source_columns is None:
1087            self._source_columns = {
1088                source_name: self.get_source_columns(source_name)
1089                for source_name, source in itertools.chain(
1090                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
1091                )
1092            }
1093        return self._source_columns
1094
1095    def _get_unambiguous_columns(
1096        self, source_columns: t.Dict[str, t.Sequence[str]]
1097    ) -> t.Mapping[str, str]:
1098        """
1099        Find all the unambiguous columns in sources.
1100
1101        Args:
1102            source_columns: Mapping of names to source columns.
1103
1104        Returns:
1105            Mapping of column name to source name.
1106        """
1107        if not source_columns:
1108            return {}
1109
1110        source_columns_pairs = list(source_columns.items())
1111
1112        first_table, first_columns = source_columns_pairs[0]
1113
1114        if len(source_columns_pairs) == 1:
1115            # Performance optimization - avoid copying first_columns if there is only one table.
1116            return SingleValuedMapping(first_columns, first_table)
1117
1118        unambiguous_columns = {col: first_table for col in first_columns}
1119        all_columns = set(unambiguous_columns)
1120
1121        for table, columns in source_columns_pairs[1:]:
1122            unique = set(columns)
1123            ambiguous = all_columns.intersection(unique)
1124            all_columns.update(columns)
1125
1126            for column in ambiguous:
1127                unambiguous_columns.pop(column, None)
1128            for column in unique.difference(ambiguous):
1129                unambiguous_columns[column] = table
1130
1131        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.
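
A brief usage sketch (hypothetical schema), wiring a Resolver to the scope built from a parsed query:

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.schema import ensure_schema
>>> scope = build_scope(sqlglot.parse_one("SELECT col FROM tbl"))
>>> resolver = Resolver(scope, ensure_schema({"tbl": {"col": "INT"}}))
>>> resolver.all_columns
{'col'}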

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
933    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
934        self.scope = scope
935        self.schema = schema
936        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
937        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
938        self._all_columns: t.Optional[t.Set[str]] = None
939        self._infer_schema = infer_schema
940        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
942    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
943        """
944        Get the table for a column name.
945
946        Args:
947            column_name: The column name to find the table for.
948        Returns:
949            The table name if it can be found/inferred.
950        """
951        if self._unambiguous_columns is None:
952            self._unambiguous_columns = self._get_unambiguous_columns(
953                self._get_all_source_columns()
954            )
955
956        table_name = self._unambiguous_columns.get(column_name)
957
958        if not table_name and self._infer_schema:
959            sources_without_schema = tuple(
960                source
961                for source, columns in self._get_all_source_columns().items()
962                if not columns or "*" in columns
963            )
964            if len(sources_without_schema) == 1:
965                table_name = sources_without_schema[0]
966
967        if table_name not in self.scope.selected_sources:
968            return exp.to_identifier(table_name)
969
970        node, _ = self.scope.selected_sources.get(table_name)
971
972        if isinstance(node, exp.Query):
973            while node and node.alias != table_name:
974                node = node.parent
975
976        node_alias = node.args.get("alias")
977        if node_alias:
978            return exp.to_identifier(node_alias.this)
979
980        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:
  The table name if it can be found/inferred.
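
For instance (a sketch, reusing the hypothetical schema from the class example above), an unambiguous column resolves to its source table:

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.schema import ensure_schema
>>> scope = build_scope(sqlglot.parse_one("SELECT col FROM tbl"))
>>> resolver = Resolver(scope, ensure_schema({"tbl": {"col": "INT"}}))
>>> resolver.get_table("col").name
'tbl'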

all_columns: Set[str]
982    @property
983    def all_columns(self) -> t.Set[str]:
984        """All available columns of all sources in this scope"""
985        if self._all_columns is None:
986            self._all_columns = {
987                column for columns in self._get_all_source_columns().values() for column in columns
988            }
989        return self._all_columns

All available columns of all sources in this scope

def get_source_columns_from_set_op(self, expression: sqlglot.expressions.Expression) -> List[str]:
 991    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
 992        if isinstance(expression, exp.Select):
 993            return expression.named_selects
 994        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
 995            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
 996            return self.get_source_columns_from_set_op(expression.this)
 997        if not isinstance(expression, exp.SetOperation):
 998            raise OptimizeError(f"Unknown set operation: {expression}")
 999
1000        set_op = expression
1001
1002        # BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
1003        on_column_list = set_op.args.get("on")
1004
1005        if on_column_list:
1006            # The resulting columns are the columns in the ON clause:
1007            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
1008            columns = [col.name for col in on_column_list]
1009        elif set_op.side or set_op.kind:
1010            side = set_op.side
1011            kind = set_op.kind
1012
1013            # Visit the children UNIONs (if any) in a post-order traversal
1014            left = self.get_source_columns_from_set_op(set_op.left)
1015            right = self.get_source_columns_from_set_op(set_op.right)
1016
1017            # We use dict.fromkeys to deduplicate keys and maintain insertion order
1018            if side == "LEFT":
1019                columns = left
1020            elif side == "FULL":
1021                columns = list(dict.fromkeys(left + right))
1022            elif kind == "INNER":
1023                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
1024        else:
1025            columns = set_op.named_selects
1026
1027        return columns
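
As a hedged sketch of the side/kind branch (assuming the bigquery dialect parses the FULL ... BY NAME form; all names are illustrative), a FULL UNION ALL BY NAME yields the deduplicated, order-preserving union of both sides' columns:

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.schema import ensure_schema
>>> sql = "SELECT a, b FROM t1 FULL UNION ALL BY NAME SELECT b, c FROM t2"
>>> scope = build_scope(sqlglot.parse_one(sql, read="bigquery"))
>>> resolver = Resolver(scope, ensure_schema({}))
>>> resolver.get_source_columns_from_set_op(scope.expression)
['a', 'b', 'c']
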
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
1029    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
1030        """Resolve the source columns for a given source `name`."""
1031        cache_key = (name, only_visible)
1032        if cache_key not in self._get_source_columns_cache:
1033            if name not in self.scope.sources:
1034                raise OptimizeError(f"Unknown table: {name}")
1035
1036            source = self.scope.sources[name]
1037
1038            if isinstance(source, exp.Table):
1039                columns = self.schema.column_names(source, only_visible)
1040            elif isinstance(source, Scope) and isinstance(
1041                source.expression, (exp.Values, exp.Unnest)
1042            ):
1043                columns = source.expression.named_selects
1044
1045                # in bigquery, unnest structs are automatically scoped as tables, so you can
1046                # directly select a struct field in a query.
1047                # this handles the case where the unnest is statically defined.
1048                if self.schema.dialect == "bigquery":
1049                    if source.expression.is_type(exp.DataType.Type.STRUCT):
1050                        for k in source.expression.type.expressions:  # type: ignore
1051                            columns.append(k.name)
1052            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
1053                columns = self.get_source_columns_from_set_op(source.expression)
1054
1055            else:
1056                select = seq_get(source.expression.selects, 0)
1057
1058                if isinstance(select, exp.QueryTransform):
1059                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
1060                    schema = select.args.get("schema")
1061                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
1062                else:
1063                    columns = source.expression.named_selects
1064
1065            node, _ = self.scope.selected_sources.get(name) or (None, None)
1066            if isinstance(node, Scope):
1067                column_aliases = node.expression.alias_column_names
1068            elif isinstance(node, exp.Expression):
1069                column_aliases = node.alias_column_names
1070            else:
1071                column_aliases = []
1072
1073            if column_aliases:
1074                # If the source's columns are aliased, their aliases shadow the corresponding column names.
1075                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
1076                columns = [
1077                    alias or name
1078                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
1079                ]
1080
1081            self._get_source_columns_cache[cache_key] = columns
1082
1083        return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.
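
A short sketch of alias shadowing (hypothetical schema): when a source is aliased with a column list, the aliases replace the schema's column names:

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> from sqlglot.schema import ensure_schema
>>> scope = build_scope(sqlglot.parse_one("SELECT * FROM tbl AS t(x, y)"))
>>> resolver = Resolver(scope, ensure_schema({"tbl": {"a": "INT", "b": "INT"}}))
>>> resolver.get_source_columns("t")
['x', 'y']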