# Module: sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # With no schema information, fall back to inferring tables for columns by default.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    # traverse_scope yields scopes innermost-first, so nested queries are
    # qualified before the queries that reference them.
    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects resolve SELECT aliases before table columns, so alias
        # references must be expanded *before* column qualification in those cases.
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if isinstance(scope.expression, exp.Select):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression
 97
 98
 99def validate_qualify_columns(expression: E) -> E:
100    """Raise an `OptimizeError` if any columns aren't qualified"""
101    all_unqualified_columns = []
102    for scope in traverse_scope(expression):
103        if isinstance(scope.expression, exp.Select):
104            unqualified_columns = scope.unqualified_columns
105
106            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
107                column = scope.external_columns[0]
108                for_table = f" for table: '{column.table}'" if column.table else ""
109                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
110
111            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
112                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
113                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
114                # this list here to ensure those in the former category will be excluded.
115                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
116                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
117
118            all_unqualified_columns.extend(unqualified_columns)
119
120    if all_unqualified_columns:
121        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
122
123    return expression
124
125
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the columns an UNPIVOT produces: the IN-clause name column (if any), then value columns."""
    field = unpivot.args.get("field")
    name_cols: t.List[exp.Column] = (
        [field.this] if isinstance(field, exp.In) and isinstance(field.this, exp.Column) else []
    )

    value_cols = (col for expr in unpivot.expressions for col in expr.find_all(exp.Column))
    return itertools.chain(name_cols, value_cols)
134
135
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived in derived_tables:
        parent = derived.parent
        # Recursive CTEs need their column list to resolve self-references, so keep it.
        if isinstance(parent, exp.With) and parent.recursive:
            continue

        alias_node = derived.args.get("alias")
        if alias_node:
            alias_node.args.pop("columns", None)
148
149
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite JOIN ... USING (...) clauses into explicit ON conditions.

    Returns:
        A mapping of each USING column name to an ordered "set" (a dict whose values
        are ignored) of the source names that expose that column.
    """
    # Maps a column name to the first (leftmost) source that exposes it.
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources not introduced by a join, i.e. the FROM clause tables.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only raise when the available columns are actually known; with "*"
                # or missing column info we can't prove the join is invalid.
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # For later joins the USING column may come from several earlier
                # sources, so compare against COALESCE(...) over all of them.
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Replace unqualified references to USING columns with COALESCE over all
        # of the sources that expose them.
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
238
239
def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Maps an output alias to (its expression, its 1-based projection position).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Expanding an alias that contains an aggregate inside another aggregate
            # would create invalid nested aggregation, unless the outer one is a window.
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        # Replace with the projection's ordinal position rather than
                        # inlining the literal, which could change GROUP BY semantics.
                        column.replace(exp.Literal.number(i))
                else:
                    # Parenthesize to preserve precedence, then drop redundant parens.
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()
328
329
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Replace positional references in GROUP BY with the projections they point at."""
    group = scope.expression.args.get("group")
    if not group:
        return

    expanded = _expand_positional_references(scope, group.expressions, dialect)
    group.set("expressions", expanded)
    scope.expression.set("group", group)
338
339
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """Expand positional ORDER BY references and qualify columns inside its aggregates."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # With a GROUP BY present, ORDER BY should reference the projections'
        # output names rather than their underlying expressions.
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )
370
371
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """
    Replace integer literals (e.g. GROUP BY 1) with the projections they refer to.

    Args:
        scope: The scope whose projections the positions index into (1-based).
        expressions: The candidate nodes to expand.
        dialect: Target dialect; BigQuery requires extra ambiguity handling.
        alias: If True, expand to the projection's alias rather than its expression.

    Returns:
        The expanded list of nodes, in the original order.
    """
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None  # computed lazily, and only for BigQuery

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the positional reference when inlining would be invalid
                # (constants, EXPLODE/UNNEST) or ambiguous for the dialect.
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
416
417
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal (e.g. GROUP BY 2).

    Args:
        scope: The scope whose projections the position indexes into.
        node: An integer literal holding the 1-based position.

    Raises:
        OptimizeError: If the position is out of range for the scope's projections.
    """
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError as e:
        # Chain the IndexError so the root cause is preserved in tracebacks.
        raise OptimizeError(f"Unknown output column: {node.name}") from e
423
424
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # A "table" qualifier that matches no known source (in this scope or, for
        # correlated subqueries, the parent scope) is likely a struct column instead.
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
464
465
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            # A qualified column must exist in its source, unless the source's columns
            # are unknown ("*") or partial qualification is explicitly allowed.
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    # Columns inside the PIVOT clause itself are qualified against the pivoted source.
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
500
501
def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        # Not a struct: nothing to expand.
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        # Build <table>.<root>.<...parts>.<field> AS <field> for each struct member.
        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
554
555
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    # The three mappings below are keyed by id(table) and record the star
    # modifiers EXCEPT (...), REPLACE (...) and RENAME (...), respectively.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            # UNPIVOT: the produced columns replace the ones consumed by the IN clause.
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare `SELECT *`: expand over every selected source.
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified star, e.g. `t.*`.
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                # Struct star, e.g. `t.struct_col.*` (BigQuery only).
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            # Not a star: keep the projection as-is.
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Unknown column set: the star can't be expanded, leave selections alone.
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING columns collapse into a single COALESCE projection, emitted once.
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
677
678
def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the EXCEPT (...) column names of a star expression, keyed by table identity."""
    except_ = expression.args.get("except")
    if not except_:
        return

    excluded = {e.name for e in except_}
    except_columns.update((id(table), excluded) for table in tables)
691
692
def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the RENAME (a AS b, ...) mapping of a star expression, keyed by table identity."""
    rename = expression.args.get("rename")
    if not rename:
        return

    mapping = {e.this.name: e.alias for e in rename}
    rename_columns.update((id(table), mapping) for table in tables)
705
706
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    """Record the REPLACE (expr AS name, ...) mapping of a star expression, keyed by table identity."""
    replace = expression.args.get("replace")
    if not replace:
        return

    replacements = {e.alias: e for e in replace}
    replace_columns.update((id(table), replacements) for table in tables)
719
720
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    # outer_columns supplies aliases imposed from outside (e.g. a CTE column list);
    # zip_longest pads with None when the two lists differ in length.
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            # An outer alias wins over any generated or pre-existing one.
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
753
754
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
760
761
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if not cte.alias_column_names:
            continue

        pushed_down = []
        for name, projection in zip(cte.alias_column_names, cte.this.expressions):
            if isinstance(projection, exp.Alias):
                projection.set("alias", name)
            else:
                projection = alias(projection, alias=name)
            pushed_down.append(projection)

        cte.this.set("expressions", pushed_down)

    return expression
792
793
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches, computed on first use and then reused.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Memoizes get_source_columns results, keyed by (source name, only_visible).
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # The column is unknown or ambiguous. If exactly one source has no
            # schema information (no known columns, or a "*" projection), assume
            # the column comes from that source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            # NOTE: exp.to_identifier(None) yields None, so an unresolved name
            # falls through to a None return here.
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Climb to the ancestor (e.g. the wrapping Subquery) whose alias
            # matches the resolved table name.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            # Prefer the node's explicit alias over the raw table name.
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            # Physical tables come from the schema; everything else exposes its
            # columns through its named selections.
            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            pseudocolumns = self._get_source_pseudocolumns(name)
            if pseudocolumns:
                # Copy before extending so cached sequences aren't mutated in place.
                columns = list(columns)
                columns.extend(c for c in pseudocolumns if c not in columns)

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_source_pseudocolumns(self, name: str) -> t.Sequence[str]:
        # Pseudocolumns that a dialect exposes without them appearing in the schema.
        if self.schema.dialect == "snowflake" and self.scope.expression.args.get("connect"):
            # When there is a CONNECT BY clause, there is only one table being scanned
            # See: https://docs.snowflake.com/en/sql-reference/constructs/connect-by
            return ["LEVEL"]
        return []

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Maps every selected and lateral source name to its resolved columns.
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column seen in more than one source can no longer be resolved
            # without qualification; drop it and record the new unique ones.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns(expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # By default, only infer missing schemas when no schema info was supplied at all.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects (or the no-schema case) require alias references to be
        # expanded before qualification rather than after it.
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        # The late expansion runs once columns are qualified, when schema info exists.
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if isinstance(scope.expression, exp.Select):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        # NOTE(review): BigQuery scopes get type-annotated here, presumably so later
        # struct-field resolution has type info available — confirm against callers.
        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unresolved = []
    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        unqualified_columns = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            column = scope.external_columns[0]
            for_table = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

        if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
            # Columns produced by the UNPIVOT itself can never be qualified, but
            # columns under its IN clause can; exclude the former by recomputing
            # against the UNPIVOT's own column list.
            unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
            unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

        unresolved.extend(unqualified_columns)

    if unresolved:
        raise OptimizeError(f"Ambiguous columns: {unresolved}")

    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure every output projection of the scope carries an explicit alias."""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            # Not an expression that produces a scope - nothing to do.
            return
    else:
        scope = scope_or_expression

    projections = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_columns)

    for i, (projection, outer_name) in enumerate(pairs):
        if projection is None:
            # More outer column names than projections; stop at the last projection.
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(projection, alias=projection.output_name or f"_col_{i}", copy=False)

        if outer_name:
            # An outer column name (e.g. from a CTE alias list) overrides the alias.
            projection.set("alias", exp.to_identifier(outer_name))

        projections.append(projection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", projections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
756def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
757    """Makes sure all identifiers that need to be quoted are quoted."""
758    return expression.transform(
759        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
760    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
763def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
764    """
765    Pushes down the CTE alias columns into the projection,
766
767    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
768
769    Example:
770        >>> import sqlglot
771        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
772        >>> pushdown_cte_alias_columns(expression).sql()
773        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
774
775    Args:
776        expression: Expression to pushdown.
777
778    Returns:
779        The expression with the CTE aliases pushed down into the projection.
780    """
781    for cte in expression.find_all(exp.CTE):
782        if cte.alias_column_names:
783            new_expressions = []
784            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
785                if isinstance(projection, exp.Alias):
786                    projection.set("alias", _alias)
787                else:
788                    projection = alias(projection, alias=_alias)
789                new_expressions.append(projection)
790            cte.this.set("expressions", new_expressions)
791
792    return expression

Pushes the CTE alias columns down into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
795class Resolver:
796    """
797    Helper for resolving columns.
798
799    This is a class so we can lazily load some things and easily share them across functions.
800    """
801
802    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
803        self.scope = scope
804        self.schema = schema
805        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
806        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
807        self._all_columns: t.Optional[t.Set[str]] = None
808        self._infer_schema = infer_schema
809        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
810
811    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
812        """
813        Get the table for a column name.
814
815        Args:
816            column_name: The column name to find the table for.
817        Returns:
818            The table name if it can be found/inferred.
819        """
820        if self._unambiguous_columns is None:
821            self._unambiguous_columns = self._get_unambiguous_columns(
822                self._get_all_source_columns()
823            )
824
825        table_name = self._unambiguous_columns.get(column_name)
826
827        if not table_name and self._infer_schema:
828            sources_without_schema = tuple(
829                source
830                for source, columns in self._get_all_source_columns().items()
831                if not columns or "*" in columns
832            )
833            if len(sources_without_schema) == 1:
834                table_name = sources_without_schema[0]
835
836        if table_name not in self.scope.selected_sources:
837            return exp.to_identifier(table_name)
838
839        node, _ = self.scope.selected_sources.get(table_name)
840
841        if isinstance(node, exp.Query):
842            while node and node.alias != table_name:
843                node = node.parent
844
845        node_alias = node.args.get("alias")
846        if node_alias:
847            return exp.to_identifier(node_alias.this)
848
849        return exp.to_identifier(table_name)
850
851    @property
852    def all_columns(self) -> t.Set[str]:
853        """All available columns of all sources in this scope"""
854        if self._all_columns is None:
855            self._all_columns = {
856                column for columns in self._get_all_source_columns().values() for column in columns
857            }
858        return self._all_columns
859
860    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
861        """Resolve the source columns for a given source `name`."""
862        cache_key = (name, only_visible)
863        if cache_key not in self._get_source_columns_cache:
864            if name not in self.scope.sources:
865                raise OptimizeError(f"Unknown table: {name}")
866
867            source = self.scope.sources[name]
868
869            if isinstance(source, exp.Table):
870                columns = self.schema.column_names(source, only_visible)
871            elif isinstance(source, Scope) and isinstance(
872                source.expression, (exp.Values, exp.Unnest)
873            ):
874                columns = source.expression.named_selects
875
876                # in bigquery, unnest structs are automatically scoped as tables, so you can
877                # directly select a struct field in a query.
878                # this handles the case where the unnest is statically defined.
879                if self.schema.dialect == "bigquery":
880                    if source.expression.is_type(exp.DataType.Type.STRUCT):
881                        for k in source.expression.type.expressions:  # type: ignore
882                            columns.append(k.name)
883            else:
884                columns = source.expression.named_selects
885
886            node, _ = self.scope.selected_sources.get(name) or (None, None)
887            if isinstance(node, Scope):
888                column_aliases = node.expression.alias_column_names
889            elif isinstance(node, exp.Expression):
890                column_aliases = node.alias_column_names
891            else:
892                column_aliases = []
893
894            if column_aliases:
895                # If the source's columns are aliased, their aliases shadow the corresponding column names.
896                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
897                columns = [
898                    alias or name
899                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
900                ]
901
902            pseudocolumns = self._get_source_pseudocolumns(name)
903            if pseudocolumns:
904                columns = list(columns)
905                columns.extend(c for c in pseudocolumns if c not in columns)
906
907            self._get_source_columns_cache[cache_key] = columns
908
909        return self._get_source_columns_cache[cache_key]
910
911    def _get_source_pseudocolumns(self, name: str) -> t.Sequence[str]:
912        if self.schema.dialect == "snowflake" and self.scope.expression.args.get("connect"):
913            # When there is a CONNECT BY clause, there is only one table being scanned
914            # See: https://docs.snowflake.com/en/sql-reference/constructs/connect-by
915            return ["LEVEL"]
916        return []
917
918    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
919        if self._source_columns is None:
920            self._source_columns = {
921                source_name: self.get_source_columns(source_name)
922                for source_name, source in itertools.chain(
923                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
924                )
925            }
926        return self._source_columns
927
928    def _get_unambiguous_columns(
929        self, source_columns: t.Dict[str, t.Sequence[str]]
930    ) -> t.Mapping[str, str]:
931        """
932        Find all the unambiguous columns in sources.
933
934        Args:
935            source_columns: Mapping of names to source columns.
936
937        Returns:
938            Mapping of column name to source name.
939        """
940        if not source_columns:
941            return {}
942
943        source_columns_pairs = list(source_columns.items())
944
945        first_table, first_columns = source_columns_pairs[0]
946
947        if len(source_columns_pairs) == 1:
948            # Performance optimization - avoid copying first_columns if there is only one table.
949            return SingleValuedMapping(first_columns, first_table)
950
951        unambiguous_columns = {col: first_table for col in first_columns}
952        all_columns = set(unambiguous_columns)
953
954        for table, columns in source_columns_pairs[1:]:
955            unique = set(columns)
956            ambiguous = all_columns.intersection(unique)
957            all_columns.update(columns)
958
959            for column in ambiguous:
960                unambiguous_columns.pop(column, None)
961            for column in unique.difference(ambiguous):
962                unambiguous_columns[column] = table
963
964        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
802    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
803        self.scope = scope
804        self.schema = schema
805        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
806        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
807        self._all_columns: t.Optional[t.Set[str]] = None
808        self._infer_schema = infer_schema
809        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
811    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
812        """
813        Get the table for a column name.
814
815        Args:
816            column_name: The column name to find the table for.
817        Returns:
818            The table name if it can be found/inferred.
819        """
820        if self._unambiguous_columns is None:
821            self._unambiguous_columns = self._get_unambiguous_columns(
822                self._get_all_source_columns()
823            )
824
825        table_name = self._unambiguous_columns.get(column_name)
826
827        if not table_name and self._infer_schema:
828            sources_without_schema = tuple(
829                source
830                for source, columns in self._get_all_source_columns().items()
831                if not columns or "*" in columns
832            )
833            if len(sources_without_schema) == 1:
834                table_name = sources_without_schema[0]
835
836        if table_name not in self.scope.selected_sources:
837            return exp.to_identifier(table_name)
838
839        node, _ = self.scope.selected_sources.get(table_name)
840
841        if isinstance(node, exp.Query):
842            while node and node.alias != table_name:
843                node = node.parent
844
845        node_alias = node.args.get("alias")
846        if node_alias:
847            return exp.to_identifier(node_alias.this)
848
849        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
851    @property
852    def all_columns(self) -> t.Set[str]:
853        """All available columns of all sources in this scope"""
854        if self._all_columns is None:
855            self._all_columns = {
856                column for columns in self._get_all_source_columns().values() for column in columns
857            }
858        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
860    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
861        """Resolve the source columns for a given source `name`."""
862        cache_key = (name, only_visible)
863        if cache_key not in self._get_source_columns_cache:
864            if name not in self.scope.sources:
865                raise OptimizeError(f"Unknown table: {name}")
866
867            source = self.scope.sources[name]
868
869            if isinstance(source, exp.Table):
870                columns = self.schema.column_names(source, only_visible)
871            elif isinstance(source, Scope) and isinstance(
872                source.expression, (exp.Values, exp.Unnest)
873            ):
874                columns = source.expression.named_selects
875
876                # in bigquery, unnest structs are automatically scoped as tables, so you can
877                # directly select a struct field in a query.
878                # this handles the case where the unnest is statically defined.
879                if self.schema.dialect == "bigquery":
880                    if source.expression.is_type(exp.DataType.Type.STRUCT):
881                        for k in source.expression.type.expressions:  # type: ignore
882                            columns.append(k.name)
883            else:
884                columns = source.expression.named_selects
885
886            node, _ = self.scope.selected_sources.get(name) or (None, None)
887            if isinstance(node, Scope):
888                column_aliases = node.expression.alias_column_names
889            elif isinstance(node, exp.Expression):
890                column_aliases = node.alias_column_names
891            else:
892                column_aliases = []
893
894            if column_aliases:
895                # If the source's columns are aliased, their aliases shadow the corresponding column names.
896                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
897                columns = [
898                    alias or name
899                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
900                ]
901
902            pseudocolumns = self._get_source_pseudocolumns(name)
903            if pseudocolumns:
904                columns = list(columns)
905                columns.extend(c for c in pseudocolumns if c not in columns)
906
907            self._get_source_columns_cache[cache_key] = columns
908
909        return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.