Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # Default: only infer missing schemas when we were given no schema at all
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    snowflake_or_oracle = dialect in ("oracle", "snowflake")

    # Each scope is processed bottom-up by traverse_scope; the per-scope steps
    # below are order-dependent (e.g. USING expansion must precede qualification)
    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and snowflake_or_oracle and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Early alias expansion: needed when there's no schema to resolve against,
        # or when the dialect resolves aliases before columns (e.g. in GROUP BY)
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        # Late alias expansion: with a schema, qualification must happen first
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression
111
112
def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unresolved: t.List[exp.Column] = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            first_external = scope.external_columns[0]
            suffix = f" for table: '{first_external.table}'" if first_external.table else ""
            raise OptimizeError(f"Column '{first_external}' could not be resolved{suffix}")

        remaining = scope.unqualified_columns
        pivots = scope.pivots

        if remaining and pivots and pivots[0].unpivot:
            # New columns produced by the UNPIVOT can't be qualified, but there may be columns
            # under the UNPIVOT's IN clause that can and should be qualified. We recompute
            # this list here to ensure those in the former category will be excluded.
            excluded = set(_unpivot_columns(pivots[0]))
            remaining = [c for c in remaining if c not in excluded]

        unresolved.extend(remaining)

    if unresolved:
        raise OptimizeError(f"Ambiguous columns: {unresolved}")

    return expression
138
139
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the UNPIVOT's name column (when its field is an IN over a column) followed by all value columns."""
    field = unpivot.args.get("field")
    name_columns = (
        [field.this]
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
        else []
    )
    value_columns = (col for expr in unpivot.expressions for col in expr.find_all(exp.Column))
    return itertools.chain(name_columns, value_columns)
148
149
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived in derived_tables:
        # Recursive CTEs rely on their declared column list, so leave it intact
        if isinstance(derived.parent, exp.With) and derived.parent.recursive:
            continue

        alias_node = derived.args.get("alias")
        if alias_node:
            alias_node.args.pop("columns", None)
162
163
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite JOIN ... USING (...) clauses in this scope into equivalent ON conditions.

    Returns a mapping of each USING column name to an ordered "set" (a dict whose
    values are all None) of the source names that contribute that column; the caller
    uses it to expand stars into COALESCE projections.
    """
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        # Remember the first source that provides each column name
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that are not joins themselves, i.e. the FROM clause tables
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        # The most recently appended source is the left-hand side of this join
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only fail when we actually know the columns on both sides
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                # First join (or a single USING column): plain equality suffices
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # Later joins must COALESCE the column across every earlier source
                # that provides it, mirroring USING's merged-column semantics
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Replace unqualified references to USING columns with COALESCE over their sources
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
260
261
def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Maps a projection alias to (its expression, its 1-based position in the SELECT list)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        # Substitute alias references under `node` with their aliased expressions.
        # resolve_table: also qualify unaliased column names against the schema.
        # literal_index: replace literal alias refs with their ordinal (GROUP BY).
        is_group_by = isinstance(node, exp.Group)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            # Expanding an alias that contains an aggregate inside another aggregate
            # would nest aggregations illegally, unless a window function intervenes
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        # Refer to the projection by its ordinal position instead
                        column.replace(exp.Literal.number(i))
                else:
                    # Substitute the aliased expression, parenthesized to preserve
                    # precedence, then drop the parens again if they are redundant
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()
350
351
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Replace positional references in the GROUP BY clause with the projections they point to."""
    group = scope.expression.args.get("group")
    if not group:
        return

    expanded = _expand_positional_references(scope, group.expressions, dialect)
    group.set("expressions", expanded)
    scope.expression.set("group", group)
360
361
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """Expand positional ORDER BY references and qualify columns inside its aggregates."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        # Qualify unqualified columns that appear inside ORDER BY aggregates
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # With a GROUP BY present, ORDER BY must reference output aliases rather
        # than the underlying (non-grouped) expressions
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )
392
393
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """
    Replace integer literals in `expressions` with the SELECT projections they refer to.

    If `alias` is True, each positional reference becomes a column named after the
    projection's alias; otherwise it becomes a copy of the projection itself, unless
    doing so would be invalid or ambiguous, in which case the literal is kept.
    """
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None  # computed lazily, and only for BigQuery

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the positional literal when expansion would be invalid
                # (constants, EXPLODE/UNNEST) or ambiguous for the dialect
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
438
439
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the projection at 1-based position `node`; it must be an aliased select."""
    index = int(node.this) - 1
    try:
        return scope.expression.selects[index].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")
445
446
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # A "table" qualifier that isn't a known source (here, nor in the parent
        # scope for a correlated subquery) is really a struct column being accessed
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
486
487
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            # Already qualified: validate the column actually exists on that source
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    # Also qualify columns appearing inside the PIVOT expressions themselves
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
522
523
def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        # Rebuild the access path as a qualified column with nested field parts
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
576
577
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    # The three dicts below are keyed by id(table) — see _add_*_columns helpers
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            # UNPIVOT produces name/value columns; its IN-clause columns disappear
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            # PIVOT consumes the columns it references and outputs its aggregations
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare `*`: expand over every selected source
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified star, e.g. `tbl.*`
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                # `tbl.struct_col.*` — flatten the struct's fields if possible
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Can't expand reliably without a full column list; bail out entirely,
                # leaving the original selections untouched
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING columns collapse into a single COALESCE projection,
                    # emitted only once across all contributing tables
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
699
700
701def _add_except_columns(
702    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
703) -> None:
704    except_ = expression.args.get("except")
705
706    if not except_:
707        return
708
709    columns = {e.name for e in except_}
710
711    for table in tables:
712        except_columns[id(table)] = columns
713
714
715def _add_rename_columns(
716    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
717) -> None:
718    rename = expression.args.get("rename")
719
720    if not rename:
721        return
722
723    columns = {e.this.name: e.alias for e in rename}
724
725    for table in tables:
726        rename_columns[id(table)] = columns
727
728
729def _add_replace_columns(
730    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
731) -> None:
732    replace = expression.args.get("replace")
733
734    if not replace:
735        return
736
737    columns = {e.alias: e for e in replace}
738
739    for table in tables:
740        replace_columns[id(table)] = columns
741
742
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            # More outer columns than selections; nothing left to alias
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            # Synthesize an alias for anonymous projections
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            # Outer column aliases (e.g. a CTE's declared column list) win
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
775
776
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Quote all identifiers in `expression` that require quoting for `dialect`."""
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
782
783
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if not cte.alias_column_names:
            continue

        projections = []
        for column_name, projection in zip(cte.alias_column_names, cte.this.expressions):
            # Existing aliases are overwritten; bare projections get wrapped in one
            if isinstance(projection, exp.Alias):
                projection.set("alias", column_name)
            else:
                projection = alias(projection, alias=column_name)
            projections.append(projection)

        cte.this.set("expressions", projections)

    return expression
814
815
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-computed caches, populated on first use and reused across calls
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Memoizes get_source_columns, keyed by (source name, only_visible)
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns (empty or star), assume the
            # unresolved column belongs to it; with several candidates we can't infer.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the ancestor that actually carries this alias
            while node and node.alias != table_name:
                node = node.parent

        # Prefer the node's explicit alias over the bare table name, if present
        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Concrete table: the schema knows its column names
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                # Any other scope (subquery, CTE, ...): use its named projections
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Maps each selected/lateral source name to its resolved columns (cached)
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column that appears in more than one source is ambiguous and dropped
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
 20def qualify_columns(
 21    expression: exp.Expression,
 22    schema: t.Dict | Schema,
 23    expand_alias_refs: bool = True,
 24    expand_stars: bool = True,
 25    infer_schema: t.Optional[bool] = None,
 26    allow_partial_qualification: bool = False,
 27) -> exp.Expression:
 28    """
 29    Rewrite sqlglot AST to have fully qualified columns.
 30
 31    Example:
 32        >>> import sqlglot
 33        >>> schema = {"tbl": {"col": "INT"}}
 34        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 35        >>> qualify_columns(expression, schema).sql()
 36        'SELECT tbl.col AS col FROM tbl'
 37
 38    Args:
 39        expression: Expression to qualify.
 40        schema: Database schema.
 41        expand_alias_refs: Whether to expand references to aliases.
 42        expand_stars: Whether to expand star queries. This is a necessary step
 43            for most of the optimizer's rules to work; do not set to False unless you
 44            know what you're doing!
 45        infer_schema: Whether to infer the schema if missing.
 46        allow_partial_qualification: Whether to allow partial qualification.
 47
 48    Returns:
 49        The qualified expression.
 50
 51    Notes:
 52        - Currently only handles a single PIVOT or UNPIVOT operator
 53    """
 54    schema = ensure_schema(schema)
 55    annotator = TypeAnnotator(schema)
 56    infer_schema = schema.empty if infer_schema is None else infer_schema
 57    dialect = Dialect.get_or_raise(schema.dialect)
 58    pseudocolumns = dialect.PSEUDOCOLUMNS
 59
 60    snowflake_or_oracle = dialect in ("oracle", "snowflake")
 61
 62    for scope in traverse_scope(expression):
 63        scope_expression = scope.expression
 64        is_select = isinstance(scope_expression, exp.Select)
 65
 66        if is_select and snowflake_or_oracle and scope_expression.args.get("connect"):
 67            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
 68            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
 69            scope_expression.transform(
 70                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
 71                copy=False,
 72            )
 73            scope.clear_cache()
 74
 75        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 76        _pop_table_column_aliases(scope.ctes)
 77        _pop_table_column_aliases(scope.derived_tables)
 78        using_column_tables = _expand_using(scope, resolver)
 79
 80        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
 81            _expand_alias_refs(
 82                scope,
 83                resolver,
 84                dialect,
 85                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
 86            )
 87
 88        _convert_columns_to_dots(scope, resolver)
 89        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
 90
 91        if not schema.empty and expand_alias_refs:
 92            _expand_alias_refs(scope, resolver, dialect)
 93
 94        if is_select:
 95            if expand_stars:
 96                _expand_stars(
 97                    scope,
 98                    resolver,
 99                    using_column_tables,
100                    pseudocolumns,
101                    annotator,
102                )
103            qualify_outputs(scope)
104
105        _expand_group_by(scope, dialect)
106        _expand_order_by(scope, resolver)
107
108        if dialect == "bigquery":
109            annotator.annotate_scope(scope)
110
111    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
114def validate_qualify_columns(expression: E) -> E:
115    """Raise an `OptimizeError` if any columns aren't qualified"""
116    all_unqualified_columns = []
117    for scope in traverse_scope(expression):
118        if isinstance(scope.expression, exp.Select):
119            unqualified_columns = scope.unqualified_columns
120
121            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
122                column = scope.external_columns[0]
123                for_table = f" for table: '{column.table}'" if column.table else ""
124                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
125
126            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
127                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
128                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
129                # this list here to ensure those in the former category will be excluded.
130                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
131                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
132
133            all_unqualified_columns.extend(unqualified_columns)
134
135    if all_unqualified_columns:
136        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
137
138    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
744def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
745    """Ensure all output columns are aliased"""
746    if isinstance(scope_or_expression, exp.Expression):
747        scope = build_scope(scope_or_expression)
748        if not isinstance(scope, Scope):
749            return
750    else:
751        scope = scope_or_expression
752
753    new_selections = []
754    for i, (selection, aliased_column) in enumerate(
755        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
756    ):
757        if selection is None:
758            break
759
760        if isinstance(selection, exp.Subquery):
761            if not selection.output_name:
762                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
763        elif not isinstance(selection, exp.Alias) and not selection.is_star:
764            selection = alias(
765                selection,
766                alias=selection.output_name or f"_col_{i}",
767                copy=False,
768            )
769        if aliased_column:
770            selection.set("alias", exp.to_identifier(aliased_column))
771
772        new_selections.append(selection)
773
774    if isinstance(scope.expression, exp.Select):
775        scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
778def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
779    """Makes sure all identifiers that need to be quoted are quoted."""
780    return expression.transform(
781        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
782    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
785def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
786    """
787    Pushes down the CTE alias columns into the projection,
788
789    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
790
791    Example:
792        >>> import sqlglot
793        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
794        >>> pushdown_cte_alias_columns(expression).sql()
795        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
796
797    Args:
798        expression: Expression to pushdown.
799
800    Returns:
801        The expression with the CTE aliases pushed down into the projection.
802    """
803    for cte in expression.find_all(exp.CTE):
804        if cte.alias_column_names:
805            new_expressions = []
806            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
807                if isinstance(projection, exp.Alias):
808                    projection.set("alias", _alias)
809                else:
810                    projection = alias(projection, alias=_alias)
811                new_expressions.append(projection)
812            cte.this.set("expressions", new_expressions)
813
814    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
817class Resolver:
818    """
819    Helper for resolving columns.
820
821    This is a class so we can lazily load some things and easily share them across functions.
822    """
823
824    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
825        self.scope = scope
826        self.schema = schema
827        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
828        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
829        self._all_columns: t.Optional[t.Set[str]] = None
830        self._infer_schema = infer_schema
831        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
832
833    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
834        """
835        Get the table for a column name.
836
837        Args:
838            column_name: The column name to find the table for.
839        Returns:
840            The table name if it can be found/inferred.
841        """
842        if self._unambiguous_columns is None:
843            self._unambiguous_columns = self._get_unambiguous_columns(
844                self._get_all_source_columns()
845            )
846
847        table_name = self._unambiguous_columns.get(column_name)
848
849        if not table_name and self._infer_schema:
850            sources_without_schema = tuple(
851                source
852                for source, columns in self._get_all_source_columns().items()
853                if not columns or "*" in columns
854            )
855            if len(sources_without_schema) == 1:
856                table_name = sources_without_schema[0]
857
858        if table_name not in self.scope.selected_sources:
859            return exp.to_identifier(table_name)
860
861        node, _ = self.scope.selected_sources.get(table_name)
862
863        if isinstance(node, exp.Query):
864            while node and node.alias != table_name:
865                node = node.parent
866
867        node_alias = node.args.get("alias")
868        if node_alias:
869            return exp.to_identifier(node_alias.this)
870
871        return exp.to_identifier(table_name)
872
873    @property
874    def all_columns(self) -> t.Set[str]:
875        """All available columns of all sources in this scope"""
876        if self._all_columns is None:
877            self._all_columns = {
878                column for columns in self._get_all_source_columns().values() for column in columns
879            }
880        return self._all_columns
881
882    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
883        """Resolve the source columns for a given source `name`."""
884        cache_key = (name, only_visible)
885        if cache_key not in self._get_source_columns_cache:
886            if name not in self.scope.sources:
887                raise OptimizeError(f"Unknown table: {name}")
888
889            source = self.scope.sources[name]
890
891            if isinstance(source, exp.Table):
892                columns = self.schema.column_names(source, only_visible)
893            elif isinstance(source, Scope) and isinstance(
894                source.expression, (exp.Values, exp.Unnest)
895            ):
896                columns = source.expression.named_selects
897
898                # in bigquery, unnest structs are automatically scoped as tables, so you can
899                # directly select a struct field in a query.
900                # this handles the case where the unnest is statically defined.
901                if self.schema.dialect == "bigquery":
902                    if source.expression.is_type(exp.DataType.Type.STRUCT):
903                        for k in source.expression.type.expressions:  # type: ignore
904                            columns.append(k.name)
905            else:
906                columns = source.expression.named_selects
907
908            node, _ = self.scope.selected_sources.get(name) or (None, None)
909            if isinstance(node, Scope):
910                column_aliases = node.expression.alias_column_names
911            elif isinstance(node, exp.Expression):
912                column_aliases = node.alias_column_names
913            else:
914                column_aliases = []
915
916            if column_aliases:
917                # If the source's columns are aliased, their aliases shadow the corresponding column names.
918                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
919                columns = [
920                    alias or name
921                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
922                ]
923
924            self._get_source_columns_cache[cache_key] = columns
925
926        return self._get_source_columns_cache[cache_key]
927
928    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
929        if self._source_columns is None:
930            self._source_columns = {
931                source_name: self.get_source_columns(source_name)
932                for source_name, source in itertools.chain(
933                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
934                )
935            }
936        return self._source_columns
937
938    def _get_unambiguous_columns(
939        self, source_columns: t.Dict[str, t.Sequence[str]]
940    ) -> t.Mapping[str, str]:
941        """
942        Find all the unambiguous columns in sources.
943
944        Args:
945            source_columns: Mapping of names to source columns.
946
947        Returns:
948            Mapping of column name to source name.
949        """
950        if not source_columns:
951            return {}
952
953        source_columns_pairs = list(source_columns.items())
954
955        first_table, first_columns = source_columns_pairs[0]
956
957        if len(source_columns_pairs) == 1:
958            # Performance optimization - avoid copying first_columns if there is only one table.
959            return SingleValuedMapping(first_columns, first_table)
960
961        unambiguous_columns = {col: first_table for col in first_columns}
962        all_columns = set(unambiguous_columns)
963
964        for table, columns in source_columns_pairs[1:]:
965            unique = set(columns)
966            ambiguous = all_columns.intersection(unique)
967            all_columns.update(columns)
968
969            for column in ambiguous:
970                unambiguous_columns.pop(column, None)
971            for column in unique.difference(ambiguous):
972                unambiguous_columns[column] = table
973
974        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
824    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
825        self.scope = scope
826        self.schema = schema
827        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
828        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
829        self._all_columns: t.Optional[t.Set[str]] = None
830        self._infer_schema = infer_schema
831        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
833    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
834        """
835        Get the table for a column name.
836
837        Args:
838            column_name: The column name to find the table for.
839        Returns:
840            The table name if it can be found/inferred.
841        """
842        if self._unambiguous_columns is None:
843            self._unambiguous_columns = self._get_unambiguous_columns(
844                self._get_all_source_columns()
845            )
846
847        table_name = self._unambiguous_columns.get(column_name)
848
849        if not table_name and self._infer_schema:
850            sources_without_schema = tuple(
851                source
852                for source, columns in self._get_all_source_columns().items()
853                if not columns or "*" in columns
854            )
855            if len(sources_without_schema) == 1:
856                table_name = sources_without_schema[0]
857
858        if table_name not in self.scope.selected_sources:
859            return exp.to_identifier(table_name)
860
861        node, _ = self.scope.selected_sources.get(table_name)
862
863        if isinstance(node, exp.Query):
864            while node and node.alias != table_name:
865                node = node.parent
866
867        node_alias = node.args.get("alias")
868        if node_alias:
869            return exp.to_identifier(node_alias.this)
870
871        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
873    @property
874    def all_columns(self) -> t.Set[str]:
875        """All available columns of all sources in this scope"""
876        if self._all_columns is None:
877            self._all_columns = {
878                column for columns in self._get_all_source_columns().values() for column in columns
879            }
880        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
882    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
883        """Resolve the source columns for a given source `name`."""
884        cache_key = (name, only_visible)
885        if cache_key not in self._get_source_columns_cache:
886            if name not in self.scope.sources:
887                raise OptimizeError(f"Unknown table: {name}")
888
889            source = self.scope.sources[name]
890
891            if isinstance(source, exp.Table):
892                columns = self.schema.column_names(source, only_visible)
893            elif isinstance(source, Scope) and isinstance(
894                source.expression, (exp.Values, exp.Unnest)
895            ):
896                columns = source.expression.named_selects
897
898                # in bigquery, unnest structs are automatically scoped as tables, so you can
899                # directly select a struct field in a query.
900                # this handles the case where the unnest is statically defined.
901                if self.schema.dialect == "bigquery":
902                    if source.expression.is_type(exp.DataType.Type.STRUCT):
903                        for k in source.expression.type.expressions:  # type: ignore
904                            columns.append(k.name)
905            else:
906                columns = source.expression.named_selects
907
908            node, _ = self.scope.selected_sources.get(name) or (None, None)
909            if isinstance(node, Scope):
910                column_aliases = node.expression.alias_column_names
911            elif isinstance(node, exp.Expression):
912                column_aliases = node.alias_column_names
913            else:
914                column_aliases = []
915
916            if column_aliases:
917                # If the source's columns are aliased, their aliases shadow the corresponding column names.
918                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
919                columns = [
920                    alias or name
921                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
922                ]
923
924            self._get_source_columns_cache[cache_key] = columns
925
926        return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.