# Module: sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # With an empty schema there is nothing to validate against, so default to inference
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    snowflake_or_oracle = dialect in ("oracle", "snowflake")

    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and snowflake_or_oracle and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects resolve projection aliases before the schema lookup; expand them early
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        # With a schema, aliases are expanded after qualification so tables are known
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression
114
115
def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    ambiguous: t.List[exp.Expression] = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        unqualified = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            column = scope.external_columns[0]
            for_table = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

        if unqualified and scope.pivots and scope.pivots[0].unpivot:
            # Columns introduced by the UNPIVOT itself can never be qualified, but columns
            # under its IN clause can be, so filter the former out of the ambiguous set
            not_qualifiable = set(_unpivot_columns(scope.pivots[0]))
            unqualified = [c for c in unqualified if c not in not_qualifiable]

        ambiguous.extend(unqualified)

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression
141
142
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the UNPIVOT's name column (if the field is an IN over a column) and value columns."""
    in_field = unpivot.args.get("field")
    if isinstance(in_field, exp.In) and isinstance(in_field.this, exp.Column):
        yield in_field.this

    for e in unpivot.expressions:
        yield from e.find_all(exp.Column)
151
152
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        parent = derived_table.parent
        if isinstance(parent, exp.With) and parent.recursive:
            # Recursive CTEs need their declared column list to resolve self-references
            continue

        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)
165
166
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite JOIN ... USING (...) clauses in `scope` into explicit ON conditions.

    Returns a mapping of each USING column name to an ordered "set" of the source
    names that provide it (a dict whose values are all None — only key order matters),
    which star expansion later uses to emit COALESCE projections.
    """
    # Maps a column name to the first source (in join order) that provides it
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that are not themselves joins, i.e. the FROM clause tables
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        # The most recently appended source is the left-hand side of this join
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only raise when we have enough column info to know the join is impossible;
                # a "*" entry means an unexpandable star, so we can't be sure
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # Later joins must COALESCE over every earlier source that has the column
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        # Replace the USING clause with the equivalent ON condition
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Rewrite unqualified references to USING columns as COALESCE over all providers
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
263
264
def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Maps an output alias to (its aliased expression, its 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Expanding an aliased aggregate inside another aggregate would nest aggregations,
            # which is invalid unless the inner reference sits under a window function
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                # Qualify the column instead of expanding it
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        # Replace with the projection's position, e.g. GROUP BY <alias> -> GROUP BY 2
                        column.replace(exp.Literal.number(i))
                else:
                    # Parenthesize to preserve precedence, then drop redundant parens
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()
353
354
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Replace positional references in GROUP BY (e.g. GROUP BY 1) with the projections they denote."""
    group = scope.expression.args.get("group")
    if not group:
        return

    expanded = _expand_positional_references(scope, group.expressions, dialect)
    group.set("expressions", expanded)
    scope.expression.set("group", group)
363
364
def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    """Expand positional references and alias references in ORDER BY and DISTINCT ON."""
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            # For DISTINCT, the expressions to expand live under its ON clause
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            # ORDER BY wraps each expression in an Ordered node; unwrap for expansion
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            # Qualify columns inside aggregates before the node is replaced
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            # With a GROUP BY, refer to projections by their output alias instead
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )
400
401
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """
    Return `expressions` with integer literals replaced by the projections they refer to.

    If `alias` is True, the positional reference is replaced by a column over the
    projection's alias; otherwise by a copy of the projected expression itself.
    """
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None  # computed lazily, only for BigQuery

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the positional literal when inlining would be invalid or ambiguous
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
446
447
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the aliased projection referenced by the 1-based positional literal `node`."""
    position = int(node.this) - 1
    try:
        selection = scope.expression.selects[position]
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selection.assert_is(exp.Alias)
453
454
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # A "table" qualifier that isn't a known source (here or in the parent of a
        # correlated subquery) is actually a struct column being dotted into
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
494
495
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            # Already qualified; just validate the column exists on that source
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    # Columns referenced inside the PIVOT clause itself are qualified separately
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
530
531
def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        # Requires prior type annotation; non-struct (or untyped) columns can't be expanded
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        # Build <table>.<column>.<field path>.<field> aliased as <field>
        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
584
585
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    # The following three maps are keyed by id(table) for each star's source table
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    # Note: only a single PIVOT/UNPIVOT is supported (see qualify_columns docstring)
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                # Columns consumed by the UNPIVOT's IN clause disappear from the output
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare * expands over every selected source
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified star, e.g. t.*
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                # Struct star, e.g. t.struct_col.*
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            # Not a star (or an unexpandable one) — keep the projection as-is
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Unknown column set — leave the query untouched
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING column: emit a single COALESCE over all providing tables
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    # Honor RENAME/REPLACE modifiers on the star, if present
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
707
708
709def _add_except_columns(
710    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
711) -> None:
712    except_ = expression.args.get("except")
713
714    if not except_:
715        return
716
717    columns = {e.name for e in except_}
718
719    for table in tables:
720        except_columns[id(table)] = columns
721
722
723def _add_rename_columns(
724    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
725) -> None:
726    rename = expression.args.get("rename")
727
728    if not rename:
729        return
730
731    columns = {e.this.name: e.alias for e in rename}
732
733    for table in tables:
734        rename_columns[id(table)] = columns
735
736
737def _add_replace_columns(
738    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
739) -> None:
740    replace = expression.args.get("replace")
741
742    if not replace:
743        return
744
745    columns = {e.alias: e for e in replace}
746
747    for table in tables:
748        replace_columns[id(table)] = columns
749
750
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    # outer_columns may be shorter or longer than the selects; zip_longest pads with None
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            # More outer columns than projections — nothing left to alias
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            # An outer column alias (e.g. from a derived table's column list) wins
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
783
784
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    resolved_dialect = Dialect.get_or_raise(dialect)
    return expression.transform(
        resolved_dialect.quote_identifier, identify=identify, copy=False
    )  # type: ignore
790
791
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        alias_names = cte.alias_column_names
        if not alias_names:
            continue

        projections = []
        for column_name, projection in zip(alias_names, cte.this.expressions):
            if isinstance(projection, exp.Alias):
                # Existing alias: overwrite it with the CTE's declared column name
                projection.set("alias", column_name)
            else:
                projection = alias(projection, alias=column_name)
            projections.append(projection)

        cte.this.set("expressions", projections)

    return expression
822
823
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-computed caches, populated on first use
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # Fallback: if exactly one source has no known columns (or a wildcard),
            # assume the column belongs to that source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the ancestor that actually carries the alias `table_name`
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            # Prefer the node's own alias identifier over the resolved name
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Physical table: columns come from the schema
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                # Derived table / subquery scope: columns are its named projections
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                # NOTE: variables renamed so they don't shadow the `name` parameter or the
                # module-level `alias` helper.
                columns = [
                    column_alias or column_name
                    for (column_name, column_alias) in itertools.zip_longest(
                        columns, column_aliases
                    )
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Map every selected and lateral source name to its resolved columns (cached)."""
        if self._source_columns is None:
            # Iterating the mappings directly yields their keys; the values were unused.
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column seen in more than one source is ambiguous: drop it
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
 20def qualify_columns(
 21    expression: exp.Expression,
 22    schema: t.Dict | Schema,
 23    expand_alias_refs: bool = True,
 24    expand_stars: bool = True,
 25    infer_schema: t.Optional[bool] = None,
 26    allow_partial_qualification: bool = False,
 27) -> exp.Expression:
 28    """
 29    Rewrite sqlglot AST to have fully qualified columns.
 30
 31    Example:
 32        >>> import sqlglot
 33        >>> schema = {"tbl": {"col": "INT"}}
 34        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 35        >>> qualify_columns(expression, schema).sql()
 36        'SELECT tbl.col AS col FROM tbl'
 37
 38    Args:
 39        expression: Expression to qualify.
 40        schema: Database schema.
 41        expand_alias_refs: Whether to expand references to aliases.
 42        expand_stars: Whether to expand star queries. This is a necessary step
 43            for most of the optimizer's rules to work; do not set to False unless you
 44            know what you're doing!
 45        infer_schema: Whether to infer the schema if missing.
 46        allow_partial_qualification: Whether to allow partial qualification.
 47
 48    Returns:
 49        The qualified expression.
 50
 51    Notes:
 52        - Currently only handles a single PIVOT or UNPIVOT operator
 53    """
 54    schema = ensure_schema(schema)
 55    annotator = TypeAnnotator(schema)
 56    infer_schema = schema.empty if infer_schema is None else infer_schema
 57    dialect = Dialect.get_or_raise(schema.dialect)
 58    pseudocolumns = dialect.PSEUDOCOLUMNS
 59
 60    snowflake_or_oracle = dialect in ("oracle", "snowflake")
 61
 62    for scope in traverse_scope(expression):
 63        scope_expression = scope.expression
 64        is_select = isinstance(scope_expression, exp.Select)
 65
 66        if is_select and snowflake_or_oracle and scope_expression.args.get("connect"):
 67            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
 68            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
 69            scope_expression.transform(
 70                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
 71                copy=False,
 72            )
 73            scope.clear_cache()
 74
 75        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 76        _pop_table_column_aliases(scope.ctes)
 77        _pop_table_column_aliases(scope.derived_tables)
 78        using_column_tables = _expand_using(scope, resolver)
 79
 80        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
 81            _expand_alias_refs(
 82                scope,
 83                resolver,
 84                dialect,
 85                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
 86            )
 87
 88        _convert_columns_to_dots(scope, resolver)
 89        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
 90
 91        if not schema.empty and expand_alias_refs:
 92            _expand_alias_refs(scope, resolver, dialect)
 93
 94        if is_select:
 95            if expand_stars:
 96                _expand_stars(
 97                    scope,
 98                    resolver,
 99                    using_column_tables,
100                    pseudocolumns,
101                    annotator,
102                )
103            qualify_outputs(scope)
104
105        _expand_group_by(scope, dialect)
106
107        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
108        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
109        _expand_order_by_and_distinct_on(scope, resolver)
110
111        if dialect == "bigquery":
112            annotator.annotate_scope(scope)
113
114    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
117def validate_qualify_columns(expression: E) -> E:
118    """Raise an `OptimizeError` if any columns aren't qualified"""
119    all_unqualified_columns = []
120    for scope in traverse_scope(expression):
121        if isinstance(scope.expression, exp.Select):
122            unqualified_columns = scope.unqualified_columns
123
124            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
125                column = scope.external_columns[0]
126                for_table = f" for table: '{column.table}'" if column.table else ""
127                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
128
129            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
130                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
131                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
132                # this list here to ensure those in the former category will be excluded.
133                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
134                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
135
136            all_unqualified_columns.extend(unqualified_columns)
137
138    if all_unqualified_columns:
139        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
140
141    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
752def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
753    """Ensure all output columns are aliased"""
754    if isinstance(scope_or_expression, exp.Expression):
755        scope = build_scope(scope_or_expression)
756        if not isinstance(scope, Scope):
757            return
758    else:
759        scope = scope_or_expression
760
761    new_selections = []
762    for i, (selection, aliased_column) in enumerate(
763        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
764    ):
765        if selection is None:
766            break
767
768        if isinstance(selection, exp.Subquery):
769            if not selection.output_name:
770                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
771        elif not isinstance(selection, exp.Alias) and not selection.is_star:
772            selection = alias(
773                selection,
774                alias=selection.output_name or f"_col_{i}",
775                copy=False,
776            )
777        if aliased_column:
778            selection.set("alias", exp.to_identifier(aliased_column))
779
780        new_selections.append(selection)
781
782    if isinstance(scope.expression, exp.Select):
783        scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
786def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
787    """Makes sure all identifiers that need to be quoted are quoted."""
788    return expression.transform(
789        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
790    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
793def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
794    """
795    Pushes down the CTE alias columns into the projection,
796
797    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
798
799    Example:
800        >>> import sqlglot
801        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
802        >>> pushdown_cte_alias_columns(expression).sql()
803        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
804
805    Args:
806        expression: Expression to pushdown.
807
808    Returns:
809        The expression with the CTE aliases pushed down into the projection.
810    """
811    for cte in expression.find_all(exp.CTE):
812        if cte.alias_column_names:
813            new_expressions = []
814            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
815                if isinstance(projection, exp.Alias):
816                    projection.set("alias", _alias)
817                else:
818                    projection = alias(projection, alias=_alias)
819                new_expressions.append(projection)
820            cte.this.set("expressions", new_expressions)
821
822    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
825class Resolver:
826    """
827    Helper for resolving columns.
828
829    This is a class so we can lazily load some things and easily share them across functions.
830    """
831
832    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
833        self.scope = scope
834        self.schema = schema
835        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
836        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
837        self._all_columns: t.Optional[t.Set[str]] = None
838        self._infer_schema = infer_schema
839        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
840
841    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
842        """
843        Get the table for a column name.
844
845        Args:
846            column_name: The column name to find the table for.
847        Returns:
848            The table name if it can be found/inferred.
849        """
850        if self._unambiguous_columns is None:
851            self._unambiguous_columns = self._get_unambiguous_columns(
852                self._get_all_source_columns()
853            )
854
855        table_name = self._unambiguous_columns.get(column_name)
856
857        if not table_name and self._infer_schema:
858            sources_without_schema = tuple(
859                source
860                for source, columns in self._get_all_source_columns().items()
861                if not columns or "*" in columns
862            )
863            if len(sources_without_schema) == 1:
864                table_name = sources_without_schema[0]
865
866        if table_name not in self.scope.selected_sources:
867            return exp.to_identifier(table_name)
868
869        node, _ = self.scope.selected_sources.get(table_name)
870
871        if isinstance(node, exp.Query):
872            while node and node.alias != table_name:
873                node = node.parent
874
875        node_alias = node.args.get("alias")
876        if node_alias:
877            return exp.to_identifier(node_alias.this)
878
879        return exp.to_identifier(table_name)
880
881    @property
882    def all_columns(self) -> t.Set[str]:
883        """All available columns of all sources in this scope"""
884        if self._all_columns is None:
885            self._all_columns = {
886                column for columns in self._get_all_source_columns().values() for column in columns
887            }
888        return self._all_columns
889
890    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
891        """Resolve the source columns for a given source `name`."""
892        cache_key = (name, only_visible)
893        if cache_key not in self._get_source_columns_cache:
894            if name not in self.scope.sources:
895                raise OptimizeError(f"Unknown table: {name}")
896
897            source = self.scope.sources[name]
898
899            if isinstance(source, exp.Table):
900                columns = self.schema.column_names(source, only_visible)
901            elif isinstance(source, Scope) and isinstance(
902                source.expression, (exp.Values, exp.Unnest)
903            ):
904                columns = source.expression.named_selects
905
906                # in bigquery, unnest structs are automatically scoped as tables, so you can
907                # directly select a struct field in a query.
908                # this handles the case where the unnest is statically defined.
909                if self.schema.dialect == "bigquery":
910                    if source.expression.is_type(exp.DataType.Type.STRUCT):
911                        for k in source.expression.type.expressions:  # type: ignore
912                            columns.append(k.name)
913            else:
914                columns = source.expression.named_selects
915
916            node, _ = self.scope.selected_sources.get(name) or (None, None)
917            if isinstance(node, Scope):
918                column_aliases = node.expression.alias_column_names
919            elif isinstance(node, exp.Expression):
920                column_aliases = node.alias_column_names
921            else:
922                column_aliases = []
923
924            if column_aliases:
925                # If the source's columns are aliased, their aliases shadow the corresponding column names.
926                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
927                columns = [
928                    alias or name
929                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
930                ]
931
932            self._get_source_columns_cache[cache_key] = columns
933
934        return self._get_source_columns_cache[cache_key]
935
936    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
937        if self._source_columns is None:
938            self._source_columns = {
939                source_name: self.get_source_columns(source_name)
940                for source_name, source in itertools.chain(
941                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
942                )
943            }
944        return self._source_columns
945
946    def _get_unambiguous_columns(
947        self, source_columns: t.Dict[str, t.Sequence[str]]
948    ) -> t.Mapping[str, str]:
949        """
950        Find all the unambiguous columns in sources.
951
952        Args:
953            source_columns: Mapping of names to source columns.
954
955        Returns:
956            Mapping of column name to source name.
957        """
958        if not source_columns:
959            return {}
960
961        source_columns_pairs = list(source_columns.items())
962
963        first_table, first_columns = source_columns_pairs[0]
964
965        if len(source_columns_pairs) == 1:
966            # Performance optimization - avoid copying first_columns if there is only one table.
967            return SingleValuedMapping(first_columns, first_table)
968
969        unambiguous_columns = {col: first_table for col in first_columns}
970        all_columns = set(unambiguous_columns)
971
972        for table, columns in source_columns_pairs[1:]:
973            unique = set(columns)
974            ambiguous = all_columns.intersection(unique)
975            all_columns.update(columns)
976
977            for column in ambiguous:
978                unambiguous_columns.pop(column, None)
979            for column in unique.difference(ambiguous):
980                unambiguous_columns[column] = table
981
982        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
832    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
833        self.scope = scope
834        self.schema = schema
835        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
836        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
837        self._all_columns: t.Optional[t.Set[str]] = None
838        self._infer_schema = infer_schema
839        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
841    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
842        """
843        Get the table for a column name.
844
845        Args:
846            column_name: The column name to find the table for.
847        Returns:
848            The table name if it can be found/inferred.
849        """
850        if self._unambiguous_columns is None:
851            self._unambiguous_columns = self._get_unambiguous_columns(
852                self._get_all_source_columns()
853            )
854
855        table_name = self._unambiguous_columns.get(column_name)
856
857        if not table_name and self._infer_schema:
858            sources_without_schema = tuple(
859                source
860                for source, columns in self._get_all_source_columns().items()
861                if not columns or "*" in columns
862            )
863            if len(sources_without_schema) == 1:
864                table_name = sources_without_schema[0]
865
866        if table_name not in self.scope.selected_sources:
867            return exp.to_identifier(table_name)
868
869        node, _ = self.scope.selected_sources.get(table_name)
870
871        if isinstance(node, exp.Query):
872            while node and node.alias != table_name:
873                node = node.parent
874
875        node_alias = node.args.get("alias")
876        if node_alias:
877            return exp.to_identifier(node_alias.this)
878
879        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
881    @property
882    def all_columns(self) -> t.Set[str]:
883        """All available columns of all sources in this scope"""
884        if self._all_columns is None:
885            self._all_columns = {
886                column for columns in self._get_all_source_columns().values() for column in columns
887            }
888        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
890    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
891        """Resolve the source columns for a given source `name`."""
892        cache_key = (name, only_visible)
893        if cache_key not in self._get_source_columns_cache:
894            if name not in self.scope.sources:
895                raise OptimizeError(f"Unknown table: {name}")
896
897            source = self.scope.sources[name]
898
899            if isinstance(source, exp.Table):
900                columns = self.schema.column_names(source, only_visible)
901            elif isinstance(source, Scope) and isinstance(
902                source.expression, (exp.Values, exp.Unnest)
903            ):
904                columns = source.expression.named_selects
905
906                # in bigquery, unnest structs are automatically scoped as tables, so you can
907                # directly select a struct field in a query.
908                # this handles the case where the unnest is statically defined.
909                if self.schema.dialect == "bigquery":
910                    if source.expression.is_type(exp.DataType.Type.STRUCT):
911                        for k in source.expression.type.expressions:  # type: ignore
912                            columns.append(k.name)
913            else:
914                columns = source.expression.named_selects
915
916            node, _ = self.scope.selected_sources.get(name) or (None, None)
917            if isinstance(node, Scope):
918                column_aliases = node.expression.alias_column_names
919            elif isinstance(node, exp.Expression):
920                column_aliases = node.alias_column_names
921            else:
922                column_aliases = []
923
924            if column_aliases:
925                # If the source's columns are aliased, their aliases shadow the corresponding column names.
926                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
927                columns = [
928                    alias or name
929                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
930                ]
931
932            self._get_source_columns_cache[cache_key] = columns
933
934        return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.