Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError, highlight_sql
  9from sqlglot.helper import seq_get
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.resolver import Resolver
 12from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 13from sqlglot.optimizer.simplify import simplify_parens
 14from sqlglot.schema import Schema, ensure_schema
 15
 16if t.TYPE_CHECKING:
 17    from sqlglot._typing import E
 18
 19
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.
        dialect: Dialect used when building the schema from a dict. Note that the
            schema's own dialect, if set, takes precedence below.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    # Only infer the schema by default when no schema information was provided
    infer_schema = schema.empty if infer_schema is None else infer_schema
    # The schema's dialect wins over the `dialect` argument
    dialect = schema.dialect or Dialect()
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        if dialect.PREFER_CTE_ALIAS_COLUMN:
            pushdown_cte_alias_columns(scope)

        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        # Turn columns named after dialect pseudocolumns (e.g. Oracle's ROWNUM)
        # into exp.Pseudocolumn nodes so they're not qualified like real columns
        _separate_pseudocolumns(scope, pseudocolumns)

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects expand alias references before qualification (early),
        # others after; both branches call the same helper
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(
            scope,
            resolver,
            allow_partial_qualification=allow_partial_qualification,
        )

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if dialect.ANNOTATE_ALL_SCOPES:
            annotator.annotate_scope(scope)

    return expression
114
115
def _qualification_error_details(
    column: exp.Column, sql: t.Optional[str]
) -> t.Tuple[str, str]:
    """Return (position, snippet) suffixes describing where `column` appears.

    `position` is e.g. "Line: 1, Col: 8" (empty string when line/col metadata is
    missing) and `snippet` is a newline-prefixed highlighted excerpt of `sql`
    (empty string when `sql` or the start/end offsets are unavailable).
    """
    meta = column.this.meta
    line = meta.get("line")
    col = meta.get("col")
    start = meta.get("start")
    end = meta.get("end")

    position = f"Line: {line}, Col: {col}" if line and col else ""

    snippet = ""
    if sql and start is not None and end is not None:
        formatted_sql = highlight_sql(sql, [(start, end)])[0]
        snippet = f"\n  {formatted_sql}"

    return position, snippet


def validate_qualify_columns(expression: E, sql: t.Optional[str] = None) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified.

    Args:
        expression: The (already qualified) expression to validate.
        sql: The original SQL text, used to render a highlighted snippet in errors.

    Returns:
        The input expression, unchanged.

    Raises:
        OptimizeError: If an external column can't be resolved, or if a column
            remains unqualified (ambiguous).
    """
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                position, snippet = _qualification_error_details(column, sql)

                error_msg = f"Column '{column.name}' could not be resolved{for_table}."
                if position:
                    error_msg += f" {position}"
                error_msg += snippet

                raise OptimizeError(error_msg)

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        first_column = all_unqualified_columns[0]
        position, snippet = _qualification_error_details(first_column, sql)

        error_msg = f"Ambiguous column '{first_column.name}'"
        if position:
            error_msg += f" ({position})"
        error_msg += snippet

        raise OptimizeError(error_msg)

    return expression
166
167
def _separate_pseudocolumns(scope: Scope, pseudocolumns: t.Set[str]) -> None:
    """Rewrite columns named after dialect pseudocolumns into `exp.Pseudocolumn` nodes."""
    if not pseudocolumns:
        return

    expression = scope.expression
    replaced_any = False

    for column in scope.columns:
        upper_name = column.name.upper()
        if upper_name not in pseudocolumns:
            continue

        # LEVEL is only treated as a pseudocolumn inside a CONNECT BY query
        is_level = upper_name == "LEVEL"
        if not is_level or (
            isinstance(expression, exp.Select) and expression.args.get("connect")
        ):
            column.replace(exp.Pseudocolumn(**column.args))
            replaced_any = True

    # The scope's cached column list is stale after replacements
    if replaced_any:
        scope.clear_cache()
188
189
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the UNPIVOT's name columns (IN-clause targets) followed by its value columns."""
    # Name columns: the left-hand side of each IN field, when it's a plain column
    for field in unpivot.fields:
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
            yield field.this

    # Value columns: every column referenced by the UNPIVOT's expressions
    for expr in unpivot.expressions:
        yield from expr.find_all(exp.Column)
199
200
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        parent = derived_table.parent
        # Recursive CTEs rely on their column list to resolve self-references, so keep it
        if isinstance(parent, exp.With) and parent.recursive:
            continue

        alias_node = derived_table.args.get("alias")
        if alias_node:
            alias_node.set("columns", None)
213
214
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """Rewrite JOIN ... USING (...) into explicit ON conditions.

    Also replaces unqualified references to USING columns with COALESCE over all
    sources that provide them, so the merged column resolves unambiguously.

    Returns:
        Mapping of each USING column name to an ordered "set" (dict with None
        values) of the source names that provide it.
    """
    # First-seen mapping of column name -> source name, in FROM/JOIN order
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that are not joins, i.e. the FROM clause table(s), come first
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        # The most recently appended source is the left-hand side of this join
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only raise when we have enough column info to know resolution failed;
                # a "*" source means columns are unknown, so we can't be sure
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # For chained joins, the LHS may need to coalesce the column
                # across all prior sources that provide it
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        # Replace USING with the equivalent explicit ON condition
        join.set("using", None)
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Replace unqualified references to USING columns with COALESCE over
        # every source that provides them
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
311
312
def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y

    Args:
        scope: The scope whose SELECT clauses are rewritten in place.
        resolver: Used to resolve unqualified columns to their source tables.
        dialect: Controls dialect-specific expansion behavior.
        expand_only_groupby: When True, only expand aliases inside the GROUP BY clause.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select) or dialect.DISABLES_ALIAS_REF_EXPANSION:
        return

    # alias name -> (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}
    replaced = False

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        # Expand alias references inside `node` (a clause of the SELECT).
        # resolve_table: also qualify unqualified columns via the resolver.
        # literal_index: replace alias refs with their projection position (for GROUP BY).
        nonlocal replaced
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                # Don't nest aggregates: expanding an alias that contains an aggregate
                # inside another aggregate is invalid, unless we're in a window
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x" reference with the projection MAX(x.b)
                # i.e HAVING MAX(MAX(x.b).b), resulting in the error: "Aggregations of aggregations are not allowed"
                if is_having and dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )
            elif dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES and (is_group_by or is_having):
                column_table = table.name if table else column.table
                if column_table in projections:
                    # BigQuery's GROUP BY and HAVING clauses get confused if the column name
                    # matches a source name and a projection. For instance:
                    # SELECT id, ARRAY_AGG(col) AS custom_fields FROM custom_fields GROUP BY id HAVING id >= 1
                    # We should not qualify "id" with "custom_fields" in either clause, since the aggregation shadows the actual table
                    # and we'd get the error: "Column custom_fields contains an aggregation function, which is not allowed in GROUP BY clause"
                    column.replace(exp.to_identifier(column.name))
                    replaced = True
                    return

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if (isinstance(alias_expr, exp.Literal) or alias_expr.is_number) and (
                    literal_index or resolve_table
                ):
                    if literal_index:
                        # GROUP BY: refer to the projection by position instead of
                        # inlining the literal (which would change semantics)
                        column.replace(exp.Literal.number(i))
                        replaced = True
                else:
                    replaced = True
                    # Parenthesize to preserve precedence, then simplify if safe
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column, dialect)
                    if simplified is not column:
                        column.replace(simplified)

    # Expand refs within the projections themselves, collecting aliases as we go
    # so later projections can refer to earlier ones
    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    # Determine whether this scope lives in the recursive (right) branch of a
    # recursive CTE's UNION
    parent_scope = scope
    on_right_sub_tree = False
    while parent_scope and not parent_scope.is_cte:
        if parent_scope.is_union:
            on_right_sub_tree = parent_scope.parent.expression.right is parent_scope.expression
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    # and we are in the recursive part (right sub tree) of the CTE
    if parent_scope and on_right_sub_tree:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    if dialect.SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS:
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    if replaced:
        scope.clear_cache()
430
431
def _expand_group_by(scope: Scope, dialect: Dialect) -> None:
    """Replace positional references in GROUP BY with the projections they point to."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    expanded = _expand_positional_references(scope, group.expressions, dialect)
    group.set("expressions", expanded)
    expression.set("group", group)
440
441
def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    """Expand positional references in ORDER BY and DISTINCT ON, in place.

    Both clauses follow the same rules, so they're handled uniformly.
    """
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            # For DISTINCT, the expressions to expand live under its ON clause
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            # ORDER BY wraps each expression in an Ordered node; unwrap it
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.dialect, alias=True
            ),
        ):
            # Qualify columns inside aggregates before replacing, since the
            # replacement may be an alias reference that hides them
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            # With GROUP BY present, refer to projections by their output names
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )
477
478
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: Dialect, alias: bool = False
) -> t.List[exp.Expression]:
    """Replace integer literals in `expressions` with the projections they reference.

    Args:
        scope: The scope whose projections positional references point into.
        expressions: The clause expressions to expand (e.g. GROUP BY items).
        dialect: Controls dialect-specific ambiguity handling.
        alias: When True, expand to the projection's alias instead of its expression.

    Returns:
        A new list where expandable positional references are replaced; other
        nodes are passed through unchanged.
    """
    new_nodes: t.List[exp.Expression] = []
    # Lazily computed set of projection names that collide with source names
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the positional reference as-is when inlining would change
                # semantics (constants/numbers), explode rows, or be ambiguous
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.is_number
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
524
525
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the aliased projection at the 1-based position given by `node`.

    Args:
        scope: The scope whose projections are indexed.
        node: An integer literal holding the 1-based projection position.

    Raises:
        OptimizeError: If the position is out of range for the scope's projections.
    """
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError as e:
        # Chain the original error so the failing lookup remains visible in tracebacks
        raise OptimizeError(f"Unknown output column: {node.name}") from e
531
532
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent STRUCT or JSON field lookup into chained `Dots`.

    These lookups may be parsed as columns (e.g. "col"."field"."field2"), but they need to be
    normalized to `Dot(Dot(...(<table>.<column>, field1), field2, ...))` to be qualified properly.
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        dot_parts = column.meta.pop("dot_parts", [])
        # A "table" that isn't a known source (here or in the parent scope of a
        # correlated subquery) is actually a struct column being dotted into
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
                was_qualified = True
            else:
                column_table = resolver.get_table(root.name)
                was_qualified = False

            if column_table:
                converted = True
                new_column = exp.column(root, table=column_table)

                if dot_parts:
                    # Remove the actual column parts from the rest of dot parts
                    new_column.meta["dot_parts"] = dot_parts[2 if was_qualified else 1 :]

                column.replace(exp.Dot.build([new_column, *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
581
582
def _qualify_columns(
    scope: Scope,
    resolver: Resolver,
    allow_partial_qualification: bool,
) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        # Already qualified: just validate the column exists in its source
        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column)

            # Skip columns that actually belong to a nested scope's projections
            if (
                column_table
                and isinstance(source := scope.sources.get(column_table.name), Scope)
                and id(column) in source.column_index
            ):
                continue

            if column_table:
                column.set("table", column_table)
            elif (
                resolver.dialect.TABLES_REFERENCEABLE_AS_COLUMNS
                and len(column.parts) == 1
                and column_name in scope.selected_sources
            ):
                # BigQuery and Postgres allow tables to be referenced as columns, treating them as structs/records
                scope.replace(column, exp.TableColumn(this=column.this))

    # Qualify any remaining unqualified columns inside PIVOT expressions
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
636
637
def _expand_struct_stars_no_parens(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column

    Returns the expanded selections, or an empty list when expansion isn't
    possible (non-struct column, anonymous/ambiguous fields, unknown field).
    """

    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        # Build <table>.<root_field>.<...nested fields...> AS <field name>
        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
690
691
def _expand_struct_stars_with_parens(expression: exp.Dot) -> t.List[exp.Alias]:
    """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column

    Returns the expanded selections, or an empty list when expansion isn't
    possible (no parenthesized struct, unexpected AST shape, unknown field).
    """

    # it is not (<sub_exp>).* pattern, which means we can't expand
    if not isinstance(expression.this, exp.Paren):
        return []

    # find column definition to get data-type
    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    parent = dot_column.parent
    starting_struct = dot_column.type

    # walk up AST and down into struct definition in sync
    while parent is not None:
        if isinstance(parent, exp.Paren):
            parent = parent.parent
            continue

        # if parent is not a dot, then something is wrong
        if not isinstance(parent, exp.Dot):
            return []

        # if the rhs of the dot is star we are done
        rhs = parent.right
        if isinstance(rhs, exp.Star):
            break

        # if it is not identifier, then something is wrong
        if not isinstance(rhs, exp.Identifier):
            return []

        # Check if current rhs identifier is in struct
        matched = False
        for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
            if struct_field_def.name == rhs.name:
                matched = True
                starting_struct = struct_field_def.kind  # update struct
                break

        if not matched:
            return []

        parent = parent.parent

    # build new aliases to expand star
    new_selections = []

    # fetch the outermost parentheses for new aliases
    outer_paren = expression.this

    # one (<outer_paren>).<field> AS <field> selection per struct field
    for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
        new_identifier = struct_field_def.this.copy()
        new_dot = exp.Dot.build([outer_paren.copy(), new_identifier])
        new_alias = alias(new_dot, new_identifier, copy=False)
        new_selections.append(new_alias)

    return new_selections
752
753
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    # All three maps are keyed by id(table): columns dropped via EXCEPT,
    # substituted via REPLACE, and renamed via RENAME star modifiers
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    # USING columns already emitted as a single COALESCE projection
    coalesced_columns = set()
    dialect = resolver.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    # Only the first PIVOT/UNPIVOT is handled (see the module-level note)
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            # Columns referenced inside the UNPIVOT's IN clause are folded into
            # the unpivoted output, so they are excluded from plain expansion
            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )

        else:
            # Columns referenced by the PIVOT itself are excluded from expansion
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    if dialect.SUPPORTS_STRUCT_STAR_EXPANSION and any(
        isinstance(col, exp.Dot) for col in scope.stars
    ):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare `SELECT *`: expands over every selected source
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified star, e.g. `tbl.*`
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif (
                dialect.SUPPORTS_STRUCT_STAR_EXPANSION
                and not dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS
            ):
                # Struct-field star, e.g. `tbl.struct_col.*`
                struct_fields = _expand_struct_stars_no_parens(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue
            elif dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS:
                # Struct-field star behind parentheses, e.g. `(tbl.struct_col).*`
                struct_fields = _expand_struct_stars_with_parens(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            # Not a star (or one we couldn't expand): keep the projection as-is
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns and dialect.EXCLUDES_PSEUDOCOLUMNS_FROM_STAR:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # The column set of some source is unknown: abort without
                # touching the query's selections at all
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # A USING column shared by several tables is emitted once, as a
                    # COALESCE over all of them. Rebinding `tables` here does NOT
                    # affect the ongoing `for table in tables` loop, which holds a
                    # reference to the original list object.
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    # Apply RENAME/REPLACE modifiers; only alias when the name changed
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
885
886
887def _add_except_columns(
888    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
889) -> None:
890    except_ = expression.args.get("except_")
891
892    if not except_:
893        return
894
895    columns = {e.name for e in except_}
896
897    for table in tables:
898        except_columns[id(table)] = columns
899
900
901def _add_rename_columns(
902    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
903) -> None:
904    rename = expression.args.get("rename")
905
906    if not rename:
907        return
908
909    columns = {e.this.name: e.alias for e in rename}
910
911    for table in tables:
912        rename_columns[id(table)] = columns
913
914
915def _add_replace_columns(
916    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
917) -> None:
918    replace = expression.args.get("replace")
919
920    if not replace:
921        return
922
923    columns = {e.alias: e for e in replace}
924
925    for table in tables:
926        replace_columns[id(table)] = columns
927
928
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        maybe_scope = build_scope(scope_or_expression)
        if not isinstance(maybe_scope, Scope):
            # Expression kind doesn't produce a scope; nothing to qualify
            return
        scope = maybe_scope
    else:
        scope = scope_or_expression

    projections: t.List[exp.Expression] = []
    paired = itertools.zip_longest(scope.expression.selects, scope.outer_columns)

    for idx, (projection, outer_name) in enumerate(paired):
        # Stop once projections run out, or on a TRANSFORM clause
        if projection is None or isinstance(projection, exp.QueryTransform):
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                # Subquery projections take a table alias rather than a column alias
                projection.set(
                    "alias", exp.TableAlias(this=exp.to_identifier(f"_col_{idx}"))
                )
        elif not projection.is_star and not isinstance(projection, (exp.Alias, exp.Aliases)):
            # Unaliased, non-star projection: fall back to a positional name
            projection = alias(
                projection, alias=projection.output_name or f"_col_{idx}", copy=False
            )

        if outer_name:
            # An outer column name (if present) overrides any alias set above
            projection.set("alias", exp.to_identifier(outer_name))

        projections.append(projection)

    if projections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", projections)
961
962
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Quote every identifier in `expression` that the target dialect requires to be quoted."""
    quoting_dialect = Dialect.get_or_raise(dialect)
    return expression.transform(
        quoting_dialect.quote_identifier, identify=identify, copy=False
    )  # type: ignore
968
969
def pushdown_cte_alias_columns(scope: Scope) -> None:
    """
    Push each CTE's declared alias column names down onto the CTE's projections.

    This step is useful in Snowflake where the CTE alias columns can be
    referenced in the HAVING.

    Args:
        scope: Scope whose CTEs should have their alias columns pushed down.
    """
    for cte in scope.ctes:
        # Only CTEs with an explicit column list wrapping a plain SELECT apply
        if not cte.alias_column_names or not isinstance(cte.this, exp.Select):
            continue

        rewritten = []
        for column_name, projection in zip(cte.alias_column_names, cte.this.expressions):
            if isinstance(projection, exp.Alias):
                # Overwrite the projection's existing alias in place
                projection.set("alias", exp.to_identifier(column_name))
            else:
                projection = alias(projection, alias=column_name)
            rewritten.append(projection)

        cte.this.set("expressions", rewritten)
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.
        dialect: The dialect used to build/interpret the schema.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    # Default: infer the schema from the query only when no schema was provided
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = schema.dialect or Dialect()
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        if dialect.PREFER_CTE_ALIAS_COLUMN:
            pushdown_cte_alias_columns(scope)

        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        _separate_pseudocolumns(scope, pseudocolumns)

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        # Drop column-list aliases on CTEs and derived tables before resolving
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias refs are expanded BEFORE qualification when there's no schema to
        # resolve against, or when the dialect forces early expansion
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(
            scope,
            resolver,
            allow_partial_qualification=allow_partial_qualification,
        )

        # With a schema available, alias refs are expanded AFTER qualification instead
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            # Ensure every SELECT output column ends up aliased
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if dialect.ANNOTATE_ALL_SCOPES:
            annotator.annotate_scope(scope)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E, sql: Optional[str] = None) -> ~E:
def _column_location_parts(column: exp.Column, sql: t.Optional[str]) -> t.Tuple[str, str]:
    """Build the position text and highlighted-SQL snippet for a column's error message."""
    # Source position recorded by the parser on the column's identifier
    meta = column.this.meta
    line, col = meta.get("line"), meta.get("col")
    start, end = meta.get("start"), meta.get("end")

    position = f"Line: {line}, Col: {col}" if line and col else ""

    snippet = ""
    if sql and start is not None and end is not None:
        # Highlight the offending span within the original SQL text
        snippet = f"\n  {highlight_sql(sql, [(start, end)])[0]}"

    return position, snippet


def validate_qualify_columns(expression: E, sql: t.Optional[str] = None) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified.

    Args:
        expression: The expression to validate.
        sql: Original SQL text, if available; used to highlight the offending
            span in error messages.

    Returns:
        The expression, unchanged.

    Raises:
        OptimizeError: If a column could not be resolved or remains ambiguous.
    """
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                # An external column outside a correlated subquery / pivot means
                # the column couldn't be resolved against any known source
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                position, snippet = _column_location_parts(column, sql)

                error_msg = f"Column '{column.name}' could not be resolved{for_table}."
                if position:
                    error_msg += f" {position}"
                error_msg += snippet

                raise OptimizeError(error_msg)

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        first_column = all_unqualified_columns[0]
        position, snippet = _column_location_parts(first_column, sql)

        error_msg = f"Ambiguous column '{first_column.name}'"
        if position:
            error_msg += f" ({position})"
        error_msg += snippet

        raise OptimizeError(error_msg)

    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            # Expression kind doesn't produce a scope; nothing to qualify
            return
    else:
        scope = scope_or_expression

    new_selections = []
    # Pair each projection with the corresponding outer column name, if any;
    # NOTE(review): outer_columns presumably come from an enclosing column list
    # (e.g. a CTE/derived-table alias) — confirm against Scope's definition
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None or isinstance(selection, exp.QueryTransform):
            # Ran out of projections, or hit a TRANSFORM clause: stop aliasing
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                # Subquery projections take a table alias rather than a column alias
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star:
            # Unaliased, non-star projection: fall back to a positional name
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            # An outer column name overrides any alias set above
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression whose identifiers should be quoted.
        dialect: Dialect (or dialect name) whose quoting rules apply.
        identify: Forwarded to the dialect's `quote_identifier` transform.

    Returns:
        The same expression, mutated in place (copy=False).
    """
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns(scope: sqlglot.optimizer.scope.Scope) -> None:
def pushdown_cte_alias_columns(scope: Scope) -> None:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Args:
        scope: Scope to find ctes to pushdown aliases.
    """
    for cte in scope.ctes:
        # Only CTEs with an explicit column list wrapping a plain SELECT apply
        if cte.alias_column_names and isinstance(cte.this, exp.Select):
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    # Overwrite the projection's existing alias in place
                    projection.set("alias", exp.to_identifier(_alias))
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Arguments:
  • scope: Scope to find ctes to pushdown aliases.