sqlglot.optimizer.qualify_columns

   1from __future__ import annotations
   2
   3import itertools
   4import typing as t
   5
   6from sqlglot import alias, exp
   7from sqlglot.dialects.dialect import Dialect, DialectType
   8from sqlglot.errors import OptimizeError, highlight_sql
   9from sqlglot.helper import seq_get
  10from sqlglot.optimizer.annotate_types import TypeAnnotator
  11from sqlglot.optimizer.resolver import Resolver
  12from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
  13from sqlglot.optimizer.simplify import simplify_parens
  14from sqlglot.schema import Schema, ensure_schema
  15
  16if t.TYPE_CHECKING:
  17    from sqlglot._typing import E
  18    from collections.abc import Iterator, Iterable
  19
  20
def qualify_columns(
    expression: exp.Expr,
    schema: dict[str, object] | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: bool | None = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expr:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expr to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.
        dialect: Dialect used when building the schema; the qualification rules below
            use the schema's dialect (falling back to a default Dialect).

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    # With an empty schema there is nothing to validate against, so default to inferring one
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = schema.dialect or Dialect()
    pseudocolumns = dialect.PSEUDOCOLUMNS

    # Each scope is processed independently, innermost first (traverse_scope order)
    for scope in traverse_scope(expression):
        if dialect.PREFER_CTE_ALIAS_COLUMN:
            pushdown_cte_alias_columns(scope)

        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        # Convert dialect pseudocolumns (e.g. Oracle's LEVEL) so they aren't qualified below
        _separate_pseudocolumns(scope, pseudocolumns)

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects require alias expansion before qualification (or have no schema info)
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(
            scope,
            resolver,
            allow_partial_qualification=allow_partial_qualification,
        )

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if dialect.ANNOTATE_ALL_SCOPES:
            annotator.annotate_scope(scope)

    return expression
 115
 116
 117def validate_qualify_columns(expression: E, sql: str | None = None) -> E:
 118    """Raise an `OptimizeError` if any columns aren't qualified"""
 119    all_unqualified_columns = []
 120    for scope in traverse_scope(expression):
 121        if isinstance(scope.expression, exp.Select):
 122            unqualified_columns = scope.unqualified_columns
 123
 124            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 125                column = scope.external_columns[0]
 126                for_table = f" for table: '{column.table}'" if column.table else ""
 127                line = column.this.meta.get("line")
 128                col = column.this.meta.get("col")
 129                start = column.this.meta.get("start")
 130                end = column.this.meta.get("end")
 131
 132                error_msg = f"Column '{column.name}' could not be resolved{for_table}."
 133                if line and col:
 134                    error_msg += f" Line: {line}, Col: {col}"
 135                if sql and start is not None and end is not None:
 136                    formatted_sql = highlight_sql(sql, [(start, end)])[0]
 137                    error_msg += f"\n  {formatted_sql}"
 138
 139                raise OptimizeError(error_msg)
 140
 141            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 142                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 143                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 144                # this list here to ensure those in the former category will be excluded.
 145                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 146                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 147
 148            all_unqualified_columns.extend(unqualified_columns)
 149
 150    if all_unqualified_columns:
 151        first_column = all_unqualified_columns[0]
 152        line = first_column.this.meta.get("line")
 153        col = first_column.this.meta.get("col")
 154        start = first_column.this.meta.get("start")
 155        end = first_column.this.meta.get("end")
 156
 157        error_msg = f"Ambiguous column '{first_column.name}'"
 158        if line and col:
 159            error_msg += f" (Line: {line}, Col: {col})"
 160        if sql and start is not None and end is not None:
 161            formatted_sql = highlight_sql(sql, [(start, end)])[0]
 162            error_msg += f"\n  {formatted_sql}"
 163
 164        raise OptimizeError(error_msg)
 165
 166    return expression
 167
 168
 169def _separate_pseudocolumns(scope: Scope, pseudocolumns: set[str]) -> None:
 170    if not pseudocolumns:
 171        return
 172
 173    has_pseudocolumns = False
 174    scope_expression = scope.expression
 175
 176    for column in scope.columns:
 177        name = column.name.upper()
 178        if name not in pseudocolumns:
 179            continue
 180
 181        if name != "LEVEL" or (
 182            isinstance(scope_expression, exp.Select) and scope_expression.args.get("connect")
 183        ):
 184            column.replace(exp.Pseudocolumn(**column.args))
 185            has_pseudocolumns = True
 186
 187    if has_pseudocolumns:
 188        scope.clear_cache()
 189
 190
 191def _unpivot_columns(unpivot: exp.Pivot) -> Iterator[exp.Column]:
 192    name_columns = [
 193        field.this
 194        for field in unpivot.fields
 195        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
 196    ]
 197    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
 198
 199    return itertools.chain(name_columns, value_columns)
 200
 201
 202def _pop_table_column_aliases(derived_tables: Iterable[exp.Expr]) -> None:
 203    """
 204    Remove table column aliases.
 205
 206    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 207    """
 208    for derived_table in derived_tables:
 209        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
 210            continue
 211        table_alias = derived_table.args.get("alias")
 212        if table_alias:
 213            table_alias.set("columns", None)
 214
 215
def _expand_using(scope: Scope, resolver: Resolver) -> dict[str, t.Any]:
    """
    Rewrite JOIN ... USING (...) clauses into equivalent explicit ON conditions, and
    replace unqualified references to USING columns with COALESCE over the sources
    that provide them.

    Returns:
        Mapping of each auto-joined column name to an ordered set of source names
        (represented as a dict whose values are all None, for key ordering only).
    """
    # Running map of column name -> first source (in FROM/JOIN order) that provides it
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that aren't themselves joins, i.e. the FROM clause tables
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: dict[str, dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        # The most recently seen source acts as the left-hand side of this join
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only raise when we actually know the columns on both sides
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expr = exp.column(identifier, table=table)
            else:
                # Later joins: the USING column may exist in several earlier sources,
                # so COALESCE across all of them to mirror engine semantics
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        # Replace the USING clause with the equivalent ON condition
        join.set("using", None)
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expr = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
 312
 313
 314def _expand_alias_refs(
 315    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
 316) -> None:
 317    """
 318    Expand references to aliases.
 319    Example:
 320        SELECT y.foo AS bar, bar * 2 AS baz FROM y
 321     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
 322    """
 323    expression = scope.expression
 324
 325    if not isinstance(expression, exp.Select) or dialect.DISABLES_ALIAS_REF_EXPANSION:
 326        return
 327
 328    alias_to_expression: dict[str, tuple[exp.Expr, int]] = {}
 329    projections = {s.alias_or_name for s in expression.selects}
 330    replaced = False
 331
 332    def replace_columns(
 333        node: exp.Expr | None, resolve_table: bool = False, literal_index: bool = False
 334    ) -> None:
 335        nonlocal replaced
 336        is_group_by = isinstance(node, exp.Group)
 337        is_having = isinstance(node, exp.Having)
 338        if not node or (expand_only_groupby and not is_group_by):
 339            return
 340
 341        for column in walk_in_scope(node, prune=lambda node: node.is_star):
 342            if not isinstance(column, exp.Column):
 343                continue
 344
 345            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
 346            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
 347            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
 348            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
 349            if expand_only_groupby and is_group_by and column.parent is not node:
 350                continue
 351
 352            skip_replace = False
 353            table = resolver.get_table(column.name) if resolve_table and not column.table else None
 354            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
 355
 356            if alias_expr:
 357                skip_replace = bool(
 358                    alias_expr.find(exp.AggFunc)
 359                    and column.find_ancestor(exp.AggFunc)
 360                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
 361                )
 362
 363                # BigQuery's having clause gets confused if an alias matches a source.
 364                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
 365                # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x" reference with the projection MAX(x.b)
 366                # i.e HAVING MAX(MAX(x.b).b), resulting in the error: "Aggregations of aggregations are not allowed"
 367                if is_having and dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
 368                    skip_replace = skip_replace or any(
 369                        node.parts[0].name in projections
 370                        for node in alias_expr.find_all(exp.Column)
 371                    )
 372            elif dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES and (is_group_by or is_having):
 373                column_table = table.name if table else column.table
 374                if column_table in projections:
 375                    # BigQuery's GROUP BY and HAVING clauses get confused if the column name
 376                    # matches a source name and a projection. For instance:
 377                    # SELECT id, ARRAY_AGG(col) AS custom_fields FROM custom_fields GROUP BY id HAVING id >= 1
 378                    # We should not qualify "id" with "custom_fields" in either clause, since the aggregation shadows the actual table
 379                    # and we'd get the error: "Column custom_fields contains an aggregation function, which is not allowed in GROUP BY clause"
 380                    column.replace(exp.to_identifier(column.name))
 381                    replaced = True
 382                    return
 383
 384            if table and (not alias_expr or skip_replace):
 385                column.set("table", table)
 386            elif not column.table and alias_expr and not skip_replace:
 387                if (isinstance(alias_expr, exp.Literal) or alias_expr.is_number) and (
 388                    literal_index or resolve_table
 389                ):
 390                    if literal_index:
 391                        column.replace(exp.Literal.number(i))
 392                        replaced = True
 393                else:
 394                    replaced = True
 395                    column = column.replace(exp.paren(alias_expr))
 396                    simplified = simplify_parens(column, dialect)
 397                    if simplified is not column:
 398                        column.replace(simplified)
 399
 400    for i, projection in enumerate(expression.selects):
 401        replace_columns(projection)
 402        if isinstance(projection, exp.Alias):
 403            alias_to_expression[projection.alias] = (projection.this, i + 1)
 404
 405    parent_scope: Scope | None = scope
 406    on_right_sub_tree = False
 407    while parent_scope and not parent_scope.is_cte:
 408        if parent_scope := parent_scope.parent:
 409            if isinstance(parent_scope.expression, exp.Union):
 410                on_right_sub_tree = parent_scope.expression.right is parent_scope.expression
 411
 412    # We shouldn't expand aliases if they match the recursive CTE's columns
 413    # and we are in the recursive part (right sub tree) of the CTE
 414    if parent_scope and on_right_sub_tree:
 415        if cte := parent_scope.expression.parent:
 416            with_ = cte.find_ancestor(exp.With)
 417            if with_ and with_.recursive:
 418                for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
 419                    alias_to_expression.pop(recursive_cte_column.output_name, None)
 420
 421    replace_columns(expression.args.get("where"))
 422    replace_columns(expression.args.get("group"), literal_index=True)
 423    replace_columns(expression.args.get("having"), resolve_table=True)
 424    replace_columns(expression.args.get("qualify"), resolve_table=True)
 425
 426    if dialect.SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS:
 427        for join in expression.args.get("joins") or []:
 428            replace_columns(join)
 429
 430    if replaced:
 431        scope.clear_cache()
 432
 433
 434def _expand_group_by(scope: Scope, dialect: Dialect) -> None:
 435    expression = scope.expression
 436    group = expression.args.get("group")
 437    if not group:
 438        return
 439
 440    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
 441    expression.set("group", group)
 442
 443
def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references (e.g. ORDER BY 1) in the ORDER BY and DISTINCT ON
    clauses of `scope`'s expression, qualifying columns inside aggregates as needed.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Selectable):
        return

    # TODO (mypyc): rebind to exp.Expr to avoid Selectable trait vtable dispatch for .args
    expr: exp.Expr = expression

    for modifier_key in ("order", "distinct"):
        modifier = expr.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            # For DISTINCT, the expandable expressions live under its ON clause
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expr):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            # ORDER BY wraps each item in an Ordered node; unwrap before expansion
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.dialect, alias=True
            ),
        ):
            # Qualify columns inside aggregates before swapping in the expanded node
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if expr.args.get("group"):
            # With a GROUP BY present, rewrite the modifier items to reference the
            # projections' aliases instead of the underlying expressions
            selects = {s.this: exp.column(s.alias_or_name) for s in expression.selects}

            for node in modifier_expressions:
                node.replace(
                    exp.to_identifier(_select_by_pos(expression, node).alias)
                    if node.is_int
                    else selects.get(node, node)
                )
 487
 488
def _expand_positional_references(
    scope: Scope, expressions: Iterable[exp.Expr], dialect: Dialect, alias: bool = False
) -> list[exp.Expr]:
    """
    Replace integer literals in `expressions` with the 1-based projection they refer to.

    Args:
        scope: The scope the expressions belong to.
        expressions: Candidate nodes (e.g. GROUP BY / ORDER BY items).
        dialect: Used for dialect-specific ambiguity checks (see below).
        alias: If True, substitute a column referencing the projection's alias;
            otherwise copy the projected expression itself.

    Returns:
        The expanded list; non-positional nodes are passed through unchanged.
    """
    new_nodes: list[exp.Expr] = []
    # Lazily computed set of projection names that collide with source names
    ambiguous_projections = None

    expression = scope.expression

    if not isinstance(expression, exp.Selectable):
        return new_nodes

    for node in expressions:
        if node.is_int and isinstance(node, exp.Literal):
            select = _select_by_pos(expression, node)

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                # TODO (mypyc): use a separate variable to avoid reusing `select` (Alias) with a different type
                select_expr: exp.Expr = select.this

                if dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select_expr.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the positional reference as-is when expansion would be invalid
                # (constants/numbers, exploding projections, or ambiguous names)
                if (
                    isinstance(select_expr, exp.CONSTANTS)
                    or select_expr.is_number
                    or select_expr.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select_expr.copy())
        else:
            new_nodes.append(node)

    return new_nodes
 540
 541
 542def _select_by_pos(expression: exp.Selectable, node: exp.Literal) -> exp.Alias:
 543    try:
 544        return expression.selects[int(node.this) - 1].assert_is(exp.Alias)
 545    except IndexError:
 546        raise OptimizeError(f"Unknown output column: {node.name}")
 547
 548
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent STRUCT or JSON field lookup into chained `Dots`.

    These lookups may be parsed as columns (e.g. "col"."field"."field2"), but they need to be
    normalized to `Dot(Dot(...(<table>.<column>, field1), field2, ...))` to be qualified properly.
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: str | exp.Identifier | None = column.table
        dot_parts = column.meta.pop("dot_parts", [])
        # Only rewrite when the column's "table" doesn't name a known source here
        # (or in the parent scope for correlated subqueries) -- then it's really a
        # struct/JSON field access, not a table qualifier
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if isinstance(root, exp.Identifier) and root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
                was_qualified = True
            else:
                column_table = resolver.get_table(root.name)
                was_qualified = False

            if column_table:
                converted = True
                new_column = exp.column(root, table=column_table)

                if dot_parts:
                    # Remove the actual column parts from the rest of dot parts
                    new_column.meta["dot_parts"] = dot_parts[2 if was_qualified else 1 :]

                column.replace(exp.Dot.build([new_column, *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
 597
 598
def _qualify_columns(
    scope: Scope,
    resolver: Resolver,
    allow_partial_qualification: bool,
) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            # Already qualified: validate the column against the source's known columns
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            # If the column is under the Pivot expression, we need to qualify it
            # using the name of the pivoted source instead of the pivot's alias
            # NOTE(review): the condition below matches columns *without* a Pivot
            # ancestor, i.e. outside the Pivot expression -- the comment above may
            # be stale; confirm against the pivot-qualification tests
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            table = resolver.get_table(column)

            # Skip columns that the resolved source scope already tracks by identity
            if (
                table
                and isinstance(source := scope.sources.get(table.name), Scope)
                and id(column) in source.column_index
            ):
                continue

            if table:
                column.set("table", table)
            elif (
                resolver.dialect.TABLES_REFERENCEABLE_AS_COLUMNS
                and len(column.parts) == 1
                and column_name in scope.selected_sources
            ):
                # BigQuery and Postgres allow tables to be referenced as columns, treating them as structs/records
                scope.replace(column, exp.TableColumn(this=column.this))

    # Qualify any remaining unqualified columns inside the PIVOT expressions themselves
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                table = resolver.get_table(column.name)
                if table:
                    column.set("table", table)
 652
 653
def _expand_struct_stars_no_parens(
    expression: exp.Dot,
) -> list[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column.

    Returns one `exp.Alias` per struct field, or an empty list whenever expansion
    is not possible (non-struct column, anonymous/ambiguous fields, or no match).
    """

    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DType.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DType.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        # Build <table>.<column>.<...fields>.<name> aliased as <name>
        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(list[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False).assert_is(exp.Alias))

    return new_selections
 706
 707
def _expand_struct_stars_with_parens(expression: exp.Dot) -> list[exp.Alias]:
    """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column.

    Returns one `exp.Alias` per struct field, or an empty list when the pattern
    doesn't match or a field lookup can't be resolved against the struct type.
    """

    # it is not (<sub_exp>).* pattern, which means we can't expand
    if not isinstance(expression.this, exp.Paren):
        return []

    # find column definition to get data-type
    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DType.STRUCT):
        return []

    parent = dot_column.parent
    starting_struct = dot_column.type

    # walk up AST and down into struct definition in sync
    while parent is not None:
        if isinstance(parent, exp.Paren):
            # Parentheses are transparent for this walk
            parent = parent.parent
            continue

        # if parent is not a dot, then something is wrong
        if not isinstance(parent, exp.Dot):
            return []

        # if the rhs of the dot is star we are done
        rhs = parent.right
        if isinstance(rhs, exp.Star):
            break

        # if it is not identifier, then something is wrong
        if not isinstance(rhs, exp.Identifier):
            return []

        # Check if current rhs identifier is in struct
        matched = False
        for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
            if struct_field_def.name == rhs.name:
                matched = True
                starting_struct = struct_field_def.kind  # update struct
                break

        if not matched:
            return []

        parent = parent.parent

    # build new aliases to expand star
    new_selections = []

    # fetch the outermost parentheses for new aliases
    outer_paren = expression.this

    for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
        new_identifier = struct_field_def.this.copy()
        new_dot = exp.Dot.build([outer_paren.copy(), new_identifier])
        new_alias = alias(new_dot, new_identifier, copy=False).assert_is(exp.Alias)
        new_selections.append(new_alias)

    return new_selections
 768
 769
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: dict[str, t.Any],
    pseudocolumns: set[str],
    annotator: TypeAnnotator,
) -> None:
    """
    Expand stars to lists of column selections.

    Rewrites `SELECT *` / `SELECT tbl.*` (and dialect-specific struct `.*`)
    projections in `scope` into explicit column selections, honoring
    EXCEPT / REPLACE / RENAME modifiers, USING-coalesced columns and PIVOT /
    UNPIVOT output columns. Mutates the scope's expression in place.

    Args:
        scope: Scope whose selections should be expanded.
        resolver: Used to look up the visible columns of each source.
        using_column_tables: Maps a USING column name to the tables it joins,
            so the expansion emits a single COALESCE per such column.
        pseudocolumns: Upper-cased pseudocolumn names to drop from `*` when
            the dialect excludes them.
        annotator: Used to type-annotate the scope when struct-star expansion
            needs column types.
    """

    new_selections: list[exp.Expr] = []
    # keyed by id(table) because the same table name may appear for multiple sources
    except_columns: dict[int, set[str]] = {}
    replace_columns: dict[int, dict[str, exp.Alias]] = {}
    rename_columns: dict[int, dict[str, str]] = {}

    # USING columns already emitted as a COALESCE; skipped on later tables
    coalesced_columns = set()
    dialect = resolver.dialect

    pivot_output_columns = None
    pivot_exclude_columns: set[str] = set()

    # Only a single PIVOT/UNPIVOT is handled (see module docstring note)
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            # columns referenced under the UNPIVOT's IN clause are consumed by it
            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )

        else:
            # PIVOT consumes the columns it references; they must not reappear in `*`
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    if dialect.SUPPORTS_STRUCT_STAR_EXPANSION and any(
        isinstance(col, exp.Dot) for col in scope.stars
    ):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    scope_expression = scope.expression

    if not isinstance(scope_expression, exp.Selectable):
        return

    for expression in scope_expression.selects:
        tables: list[str] = []
        if isinstance(expression, exp.Star):
            # bare `*`: expands over every selected source
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if isinstance(expression, exp.Column):
                # qualified `tbl.*`: expands over a single source
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif isinstance(expression, exp.Dot):
                # struct `.*` expansion; which helper applies depends on whether
                # the dialect requires parenthesized struct access
                if (
                    dialect.SUPPORTS_STRUCT_STAR_EXPANSION
                    and not dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS
                ):
                    struct_fields = _expand_struct_stars_no_parens(expression)
                    if struct_fields:
                        new_selections.extend(struct_fields)
                        continue
                elif dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS:
                    struct_fields = _expand_struct_stars_with_parens(expression)
                    if struct_fields:
                        new_selections.extend(struct_fields)
                        continue

        if not tables:
            # not a star (or an unexpandable one): keep the selection as-is
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns and dialect.EXCLUDES_PSEUDOCOLUMNS_FROM_STAR:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # unknown column set ("*" sentinel or empty) — abort without rewriting
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    # source columns not consumed by the pivot, plus the pivot's outputs
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING column: emit COALESCE over all joined tables, once
                    coalesced_columns.add(name)
                    # TODO (mypyc): use a separate variable to avoid reusing `tables` (list) with dict type
                    using_tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in using_tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    # apply RENAME/REPLACE modifiers; only alias when the name changes
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope_expression, exp.Select):
        scope_expression.set("expressions", new_selections)
 908
 909
 910def _add_except_columns(expression: exp.Expr, tables, except_columns: dict[int, set[str]]) -> None:
 911    except_ = expression.args.get("except_")
 912
 913    if not except_:
 914        return
 915
 916    columns = {e.name for e in except_}
 917
 918    for table in tables:
 919        except_columns[id(table)] = columns
 920
 921
 922def _add_rename_columns(
 923    expression: exp.Expr, tables, rename_columns: dict[int, dict[str, str]]
 924) -> None:
 925    rename = expression.args.get("rename")
 926
 927    if not rename:
 928        return
 929
 930    columns = {e.this.name: e.alias for e in rename}
 931
 932    for table in tables:
 933        rename_columns[id(table)] = columns
 934
 935
 936def _add_replace_columns(
 937    expression: exp.Expr, tables, replace_columns: dict[int, dict[str, exp.Alias]]
 938) -> None:
 939    replace = expression.args.get("replace")
 940
 941    if not replace:
 942        return
 943
 944    columns = {e.alias: e for e in replace}
 945
 946    for table in tables:
 947        replace_columns[id(table)] = columns
 948
 949
 950def qualify_outputs(scope_or_expression: Scope | exp.Expr) -> None:
 951    """Ensure all output columns are aliased"""
 952    if isinstance(scope_or_expression, exp.Expr):
 953        scope = build_scope(scope_or_expression)
 954        if not isinstance(scope, Scope):
 955            return
 956    else:
 957        scope = scope_or_expression
 958
 959    expression = scope.expression
 960
 961    if not isinstance(expression, exp.Selectable):
 962        return
 963
 964    new_selections = []
 965
 966    for i, (selection, aliased_column) in enumerate(
 967        itertools.zip_longest(expression.selects, scope.outer_columns)
 968    ):
 969        if selection is None or isinstance(selection, exp.QueryTransform):
 970            break
 971
 972        if isinstance(selection, exp.Subquery):
 973            if not selection.output_name:
 974                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
 975        elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star:
 976            selection = alias(
 977                selection,
 978                alias=selection.output_name or f"_col_{i}",
 979                copy=False,
 980            )
 981        if aliased_column:
 982            selection.set("alias", exp.to_identifier(aliased_column))
 983
 984        new_selections.append(selection)
 985
 986    if new_selections and isinstance(expression, exp.Select):
 987        expression.set("expressions", new_selections)
 988
 989
 990def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
 991    """Makes sure all identifiers that need to be quoted are quoted."""
 992    return expression.transform(
 993        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
 994    )  # type: ignore
 995
 996
 997def pushdown_cte_alias_columns(scope: Scope) -> None:
 998    """
 999    Pushes down the CTE alias columns into the projection,
1000
1001    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
1002
1003    Args:
1004        scope: Scope to find ctes to pushdown aliases.
1005    """
1006    for cte in scope.ctes:
1007        if cte.alias_column_names and isinstance(cte.this, exp.Select):
1008            new_expressions = []
1009            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
1010                if isinstance(projection, exp.Alias):
1011                    projection.set("alias", exp.to_identifier(_alias))
1012                else:
1013                    projection = alias(projection, alias=_alias)
1014                new_expressions.append(projection)
1015            cte.this.set("expressions", new_expressions)
def qualify_columns( expression: sqlglot.expressions.core.Expr, schema: dict[str, object] | sqlglot.schema.Schema, expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: bool | None = None, allow_partial_qualification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.core.Expr:
 22def qualify_columns(
 23    expression: exp.Expr,
 24    schema: dict[str, object] | Schema,
 25    expand_alias_refs: bool = True,
 26    expand_stars: bool = True,
 27    infer_schema: bool | None = None,
 28    allow_partial_qualification: bool = False,
 29    dialect: DialectType = None,
 30) -> exp.Expr:
 31    """
 32    Rewrite sqlglot AST to have fully qualified columns.
 33
 34    Example:
 35        >>> import sqlglot
 36        >>> schema = {"tbl": {"col": "INT"}}
 37        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 38        >>> qualify_columns(expression, schema).sql()
 39        'SELECT tbl.col AS col FROM tbl'
 40
 41    Args:
 42        expression: Expr to qualify.
 43        schema: Database schema.
 44        expand_alias_refs: Whether to expand references to aliases.
 45        expand_stars: Whether to expand star queries. This is a necessary step
 46            for most of the optimizer's rules to work; do not set to False unless you
 47            know what you're doing!
 48        infer_schema: Whether to infer the schema if missing.
 49        allow_partial_qualification: Whether to allow partial qualification.
 50
 51    Returns:
 52        The qualified expression.
 53
 54    Notes:
 55        - Currently only handles a single PIVOT or UNPIVOT operator
 56    """
 57    schema = ensure_schema(schema, dialect=dialect)
 58    annotator = TypeAnnotator(schema)
 59    infer_schema = schema.empty if infer_schema is None else infer_schema
 60    dialect = schema.dialect or Dialect()
 61    pseudocolumns = dialect.PSEUDOCOLUMNS
 62
 63    for scope in traverse_scope(expression):
 64        if dialect.PREFER_CTE_ALIAS_COLUMN:
 65            pushdown_cte_alias_columns(scope)
 66
 67        scope_expression = scope.expression
 68        is_select = isinstance(scope_expression, exp.Select)
 69
 70        _separate_pseudocolumns(scope, pseudocolumns)
 71
 72        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 73        _pop_table_column_aliases(scope.ctes)
 74        _pop_table_column_aliases(scope.derived_tables)
 75        using_column_tables = _expand_using(scope, resolver)
 76
 77        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
 78            _expand_alias_refs(
 79                scope,
 80                resolver,
 81                dialect,
 82                expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF,
 83            )
 84
 85        _convert_columns_to_dots(scope, resolver)
 86        _qualify_columns(
 87            scope,
 88            resolver,
 89            allow_partial_qualification=allow_partial_qualification,
 90        )
 91
 92        if not schema.empty and expand_alias_refs:
 93            _expand_alias_refs(scope, resolver, dialect)
 94
 95        if is_select:
 96            if expand_stars:
 97                _expand_stars(
 98                    scope,
 99                    resolver,
100                    using_column_tables,
101                    pseudocolumns,
102                    annotator,
103                )
104            qualify_outputs(scope)
105
106        _expand_group_by(scope, dialect)
107
108        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
109        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
110        _expand_order_by_and_distinct_on(scope, resolver)
111
112        if dialect.ANNOTATE_ALL_SCOPES:
113            annotator.annotate_scope(scope)
114
115    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expr to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E, sql: str | None = None) -> ~E:
118def validate_qualify_columns(expression: E, sql: str | None = None) -> E:
119    """Raise an `OptimizeError` if any columns aren't qualified"""
120    all_unqualified_columns = []
121    for scope in traverse_scope(expression):
122        if isinstance(scope.expression, exp.Select):
123            unqualified_columns = scope.unqualified_columns
124
125            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
126                column = scope.external_columns[0]
127                for_table = f" for table: '{column.table}'" if column.table else ""
128                line = column.this.meta.get("line")
129                col = column.this.meta.get("col")
130                start = column.this.meta.get("start")
131                end = column.this.meta.get("end")
132
133                error_msg = f"Column '{column.name}' could not be resolved{for_table}."
134                if line and col:
135                    error_msg += f" Line: {line}, Col: {col}"
136                if sql and start is not None and end is not None:
137                    formatted_sql = highlight_sql(sql, [(start, end)])[0]
138                    error_msg += f"\n  {formatted_sql}"
139
140                raise OptimizeError(error_msg)
141
142            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
143                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
144                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
145                # this list here to ensure those in the former category will be excluded.
146                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
147                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
148
149            all_unqualified_columns.extend(unqualified_columns)
150
151    if all_unqualified_columns:
152        first_column = all_unqualified_columns[0]
153        line = first_column.this.meta.get("line")
154        col = first_column.this.meta.get("col")
155        start = first_column.this.meta.get("start")
156        end = first_column.this.meta.get("end")
157
158        error_msg = f"Ambiguous column '{first_column.name}'"
159        if line and col:
160            error_msg += f" (Line: {line}, Col: {col})"
161        if sql and start is not None and end is not None:
162            formatted_sql = highlight_sql(sql, [(start, end)])[0]
163            error_msg += f"\n  {formatted_sql}"
164
165        raise OptimizeError(error_msg)
166
167    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.core.Expr) -> None:
951def qualify_outputs(scope_or_expression: Scope | exp.Expr) -> None:
952    """Ensure all output columns are aliased"""
953    if isinstance(scope_or_expression, exp.Expr):
954        scope = build_scope(scope_or_expression)
955        if not isinstance(scope, Scope):
956            return
957    else:
958        scope = scope_or_expression
959
960    expression = scope.expression
961
962    if not isinstance(expression, exp.Selectable):
963        return
964
965    new_selections = []
966
967    for i, (selection, aliased_column) in enumerate(
968        itertools.zip_longest(expression.selects, scope.outer_columns)
969    ):
970        if selection is None or isinstance(selection, exp.QueryTransform):
971            break
972
973        if isinstance(selection, exp.Subquery):
974            if not selection.output_name:
975                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
976        elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star:
977            selection = alias(
978                selection,
979                alias=selection.output_name or f"_col_{i}",
980                copy=False,
981            )
982        if aliased_column:
983            selection.set("alias", exp.to_identifier(aliased_column))
984
985        new_selections.append(selection)
986
987    if new_selections and isinstance(expression, exp.Select):
988        expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
991def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
992    """Makes sure all identifiers that need to be quoted are quoted."""
993    return expression.transform(
994        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
995    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns(scope: sqlglot.optimizer.scope.Scope) -> None:
 998def pushdown_cte_alias_columns(scope: Scope) -> None:
 999    """
1000    Pushes down the CTE alias columns into the projection,
1001
1002    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
1003
1004    Args:
1005        scope: Scope to find ctes to pushdown aliases.
1006    """
1007    for cte in scope.ctes:
1008        if cte.alias_column_names and isinstance(cte.this, exp.Select):
1009            new_expressions = []
1010            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
1011                if isinstance(projection, exp.Alias):
1012                    projection.set("alias", exp.to_identifier(_alias))
1013                else:
1014                    projection = alias(projection, alias=_alias)
1015                new_expressions.append(projection)
1016            cte.this.set("expressions", new_expressions)

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Arguments:
  • scope: Scope to find ctes to pushdown aliases.