Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # Default: infer the schema only when no schema information was provided
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects (and schema-less qualification) need alias references
        # expanded BEFORE columns are qualified; otherwise it happens after
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            # BigQuery star/struct handling downstream relies on type annotations
            annotator.annotate_scope(scope)

    return expression
 96
 97
 98def validate_qualify_columns(expression: E) -> E:
 99    """Raise an `OptimizeError` if any columns aren't qualified"""
100    all_unqualified_columns = []
101    for scope in traverse_scope(expression):
102        if isinstance(scope.expression, exp.Select):
103            unqualified_columns = scope.unqualified_columns
104
105            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
106                column = scope.external_columns[0]
107                for_table = f" for table: '{column.table}'" if column.table else ""
108                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
109
110            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
111                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
112                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
113                # this list here to ensure those in the former category will be excluded.
114                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
115                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
116
117            all_unqualified_columns.extend(unqualified_columns)
118
119    if all_unqualified_columns:
120        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
121
122    return expression
123
124
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Iterate the UNPIVOT's name column (when its field is an IN over a column),
    followed by every column under its value expressions."""
    field = unpivot.args.get("field")
    head = [field.this] if isinstance(field, exp.In) and isinstance(field.this, exp.Column) else []
    tail = (col for expr in unpivot.expressions for col in expr.find_all(exp.Column))
    return itertools.chain(head, tail)
133
134
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        parent = derived_table.parent
        # Column aliases on recursive CTEs are left untouched
        if isinstance(parent, exp.With) and parent.recursive:
            continue

        alias_node = derived_table.args.get("alias")
        if alias_node:
            alias_node.args.pop("columns", None)
147
148
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite JOIN ... USING (...) clauses into equivalent ON conditions.

    Returns a mapping of each USING column name to an ordered "set" (a dict whose
    values are all None) of the source names that provide it, so star expansion
    can later COALESCE those columns.
    """
    # Maps column name -> first source (in FROM/JOIN order) that provides it
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources not introduced by a JOIN, i.e. the FROM-clause sources
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # For later joins the USING column may exist in several earlier
                # sources, so the LHS must COALESCE across all of them
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        # Replace the USING clause with the equivalent ON condition
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Unqualified references to a USING column are ambiguous; rewrite them
        # as a COALESCE over every source that provides the column
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.func("coalesce", *coalesce_args)

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables
232
233
def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
    """
    Replace references to SELECT-list aliases (in other projections and in
    WHERE/GROUP BY/HAVING/QUALIFY) with either the aliased expression itself or
    a table-qualified column, whichever is the correct resolution.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Maps projection alias -> (aliased expression, 1-based projection index)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Expanding an alias whose expression aggregates, inside another aggregate
            # (and not inside a window), would produce an illegal nested aggregation
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        # e.g. GROUP BY <alias of a literal> -> GROUP BY <projection ordinal>
                        column.replace(exp.Literal.number(i))
                else:
                    # Parenthesize to preserve precedence, then simplify if possible
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # The replacements above invalidate the scope's cached column/source info
    scope.clear_cache()
308
309
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Replace positional references in GROUP BY with the projections they point to."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    expanded = _expand_positional_references(scope, group.expressions, dialect)
    group.set("expressions", expanded)
    expression.set("group", group)
318
319
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Rewrite ORDER BY: expand positional references into projection aliases,
    qualify columns inside aggregates, and — when the query has a GROUP BY —
    replace ordering expressions that match a projection with its output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # Maps each projected expression to a column referencing its output name
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )
350
351
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """
    Replace integer literals in `expressions` with the 1-based projection they
    refer to. With `alias=True`, reference the projection's alias; otherwise
    copy the projected expression itself (when safe to do so).
    """
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the ordinal for constants, exploding expressions, or
                # ambiguous BigQuery references — expanding would change semantics
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
396
397
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the aliased projection referenced by a 1-based ordinal literal,
    raising `OptimizeError` when the ordinal is out of range."""
    index = int(node.this) - 1
    try:
        selection = scope.expression.selects[index]
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")
    return selection.assert_is(exp.Alias)
403
404
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # Only convert when the "table" part doesn't name an actual source
        # (in this scope, or in the parent scope for correlated subqueries)
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
444
445
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            # Already qualified: validate the column against the source's schema
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    # Qualify any remaining unqualified columns inside the PIVOT clause itself
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
480
481
def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column

    Returns the expanded per-field selections, or an empty list when expansion
    is not possible (non-struct column, anonymous/ambiguous fields, or no
    matching nested field).
    """

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        # Rebuild the field access as <table>.<root>.<field path>
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
534
535
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    # Star modifiers (EXCEPT/REPLACE/RENAME), keyed by id() of the table name they apply to
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    # Only the first PIVOT/UNPIVOT is handled (see module-level notes)
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Unqualified star: expand across every selected source
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified star, e.g. t.*
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                # Struct star, e.g. t.struct_col.*
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Can't expand without a concrete column list; leave the query untouched
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING columns are emitted once, coalesced across their sources
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
657
658
659def _add_except_columns(
660    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
661) -> None:
662    except_ = expression.args.get("except")
663
664    if not except_:
665        return
666
667    columns = {e.name for e in except_}
668
669    for table in tables:
670        except_columns[id(table)] = columns
671
672
673def _add_rename_columns(
674    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
675) -> None:
676    rename = expression.args.get("rename")
677
678    if not rename:
679        return
680
681    columns = {e.this.name: e.alias for e in rename}
682
683    for table in tables:
684        rename_columns[id(table)] = columns
685
686
687def _add_replace_columns(
688    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
689) -> None:
690    replace = expression.args.get("replace")
691
692    if not replace:
693        return
694
695    columns = {e.alias: e for e in replace}
696
697    for table in tables:
698        replace_columns[id(table)] = columns
699
700
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    # Pair each selection with its outer column alias (e.g. from a CTE column list)
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            # More outer columns than selections: nothing left to alias
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            # Outer column names take precedence over any alias set above
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
733
734
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
740
741
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if not cte.alias_column_names:
            continue

        pushed_down = []
        # Pair each alias column with the corresponding projection, in order
        for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            pushed_down.append(projection)

        cte.this.set("expressions", pushed_down)

    return expression
772
773
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        # Scope whose columns are being resolved, and the schema to consult for tables.
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, populated on first access.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Cache for get_source_columns, keyed by (source name, only_visible).
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # Unknown or ambiguous column: if exactly one source has no schema
            # information (empty column list, or a "*" placeholder), assume the
            # column belongs to that source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            # Note: to_identifier(None) yields None, matching the Optional return.
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the ancestor that actually carries this alias —
            # presumably the Subquery wrapping the query; TODO confirm.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            # Prefer the node's own alias identifier so quoting/casing is preserved.
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Physical table: the schema knows its columns.
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                # Derived table / CTE / subquery: use its projection names.
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Map every selected and lateral source name to its resolved columns.
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            # Any column already contributed by an earlier source becomes ambiguous.
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # By default, only infer missing schema info when no schema was supplied at all.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    # Each scope (outer query, subqueries, CTEs, ...) is qualified independently,
    # in the order produced by traverse_scope. Step order below is load-bearing.
    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects resolve alias references before qualification; also do it
        # early when there's no schema to qualify against.
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        # With a schema available, alias refs are expanded after qualification instead.
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        # UDTF scopes don't have conventional projections to expand or alias.
        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        # BigQuery is the only dialect that needs type annotations at this stage.
        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
 99def validate_qualify_columns(expression: E) -> E:
100    """Raise an `OptimizeError` if any columns aren't qualified"""
101    all_unqualified_columns = []
102    for scope in traverse_scope(expression):
103        if isinstance(scope.expression, exp.Select):
104            unqualified_columns = scope.unqualified_columns
105
106            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
107                column = scope.external_columns[0]
108                for_table = f" for table: '{column.table}'" if column.table else ""
109                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
110
111            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
112                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
113                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
114                # this list here to ensure those in the former category will be excluded.
115                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
116                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
117
118            all_unqualified_columns.extend(unqualified_columns)
119
120    if all_unqualified_columns:
121        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
122
123    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    aliased_selections = []
    for i, (selection, outer_name) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        # zip_longest pads with None: once the projections run out, stop.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            # Subqueries get a TableAlias rather than a column alias.
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )

        # A name imposed by the outer scope overrides any alias chosen above.
        if outer_name:
            selection.set("alias", exp.to_identifier(outer_name))

        aliased_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", aliased_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
736def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
737    """Makes sure all identifiers that need to be quoted are quoted."""
738    return expression.transform(
739        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
740    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
743def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
744    """
745    Pushes down the CTE alias columns into the projection,
746
747    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
748
749    Example:
750        >>> import sqlglot
751        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
752        >>> pushdown_cte_alias_columns(expression).sql()
753        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
754
755    Args:
756        expression: Expression to pushdown.
757
758    Returns:
759        The expression with the CTE aliases pushed down into the projection.
760    """
761    for cte in expression.find_all(exp.CTE):
762        if cte.alias_column_names:
763            new_expressions = []
764            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
765                if isinstance(projection, exp.Alias):
766                    projection.set("alias", _alias)
767                else:
768                    projection = alias(projection, alias=_alias)
769                new_expressions.append(projection)
770            cte.this.set("expressions", new_expressions)
771
772    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
775class Resolver:
776    """
777    Helper for resolving columns.
778
779    This is a class so we can lazily load some things and easily share them across functions.
780    """
781
782    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
783        self.scope = scope
784        self.schema = schema
785        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
786        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
787        self._all_columns: t.Optional[t.Set[str]] = None
788        self._infer_schema = infer_schema
789        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
790
791    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
792        """
793        Get the table for a column name.
794
795        Args:
796            column_name: The column name to find the table for.
797        Returns:
798            The table name if it can be found/inferred.
799        """
800        if self._unambiguous_columns is None:
801            self._unambiguous_columns = self._get_unambiguous_columns(
802                self._get_all_source_columns()
803            )
804
805        table_name = self._unambiguous_columns.get(column_name)
806
807        if not table_name and self._infer_schema:
808            sources_without_schema = tuple(
809                source
810                for source, columns in self._get_all_source_columns().items()
811                if not columns or "*" in columns
812            )
813            if len(sources_without_schema) == 1:
814                table_name = sources_without_schema[0]
815
816        if table_name not in self.scope.selected_sources:
817            return exp.to_identifier(table_name)
818
819        node, _ = self.scope.selected_sources.get(table_name)
820
821        if isinstance(node, exp.Query):
822            while node and node.alias != table_name:
823                node = node.parent
824
825        node_alias = node.args.get("alias")
826        if node_alias:
827            return exp.to_identifier(node_alias.this)
828
829        return exp.to_identifier(table_name)
830
831    @property
832    def all_columns(self) -> t.Set[str]:
833        """All available columns of all sources in this scope"""
834        if self._all_columns is None:
835            self._all_columns = {
836                column for columns in self._get_all_source_columns().values() for column in columns
837            }
838        return self._all_columns
839
840    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
841        """Resolve the source columns for a given source `name`."""
842        cache_key = (name, only_visible)
843        if cache_key not in self._get_source_columns_cache:
844            if name not in self.scope.sources:
845                raise OptimizeError(f"Unknown table: {name}")
846
847            source = self.scope.sources[name]
848
849            if isinstance(source, exp.Table):
850                columns = self.schema.column_names(source, only_visible)
851            elif isinstance(source, Scope) and isinstance(
852                source.expression, (exp.Values, exp.Unnest)
853            ):
854                columns = source.expression.named_selects
855
856                # in bigquery, unnest structs are automatically scoped as tables, so you can
857                # directly select a struct field in a query.
858                # this handles the case where the unnest is statically defined.
859                if self.schema.dialect == "bigquery":
860                    if source.expression.is_type(exp.DataType.Type.STRUCT):
861                        for k in source.expression.type.expressions:  # type: ignore
862                            columns.append(k.name)
863            else:
864                columns = source.expression.named_selects
865
866            node, _ = self.scope.selected_sources.get(name) or (None, None)
867            if isinstance(node, Scope):
868                column_aliases = node.expression.alias_column_names
869            elif isinstance(node, exp.Expression):
870                column_aliases = node.alias_column_names
871            else:
872                column_aliases = []
873
874            if column_aliases:
875                # If the source's columns are aliased, their aliases shadow the corresponding column names.
876                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
877                columns = [
878                    alias or name
879                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
880                ]
881
882            self._get_source_columns_cache[cache_key] = columns
883
884        return self._get_source_columns_cache[cache_key]
885
886    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
887        if self._source_columns is None:
888            self._source_columns = {
889                source_name: self.get_source_columns(source_name)
890                for source_name, source in itertools.chain(
891                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
892                )
893            }
894        return self._source_columns
895
896    def _get_unambiguous_columns(
897        self, source_columns: t.Dict[str, t.Sequence[str]]
898    ) -> t.Mapping[str, str]:
899        """
900        Find all the unambiguous columns in sources.
901
902        Args:
903            source_columns: Mapping of names to source columns.
904
905        Returns:
906            Mapping of column name to source name.
907        """
908        if not source_columns:
909            return {}
910
911        source_columns_pairs = list(source_columns.items())
912
913        first_table, first_columns = source_columns_pairs[0]
914
915        if len(source_columns_pairs) == 1:
916            # Performance optimization - avoid copying first_columns if there is only one table.
917            return SingleValuedMapping(first_columns, first_table)
918
919        unambiguous_columns = {col: first_table for col in first_columns}
920        all_columns = set(unambiguous_columns)
921
922        for table, columns in source_columns_pairs[1:]:
923            unique = set(columns)
924            ambiguous = all_columns.intersection(unique)
925            all_columns.update(columns)
926
927            for column in ambiguous:
928                unambiguous_columns.pop(column, None)
929            for column in unique.difference(ambiguous):
930                unambiguous_columns[column] = table
931
932        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
782    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
783        self.scope = scope
784        self.schema = schema
785        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
786        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
787        self._all_columns: t.Optional[t.Set[str]] = None
788        self._infer_schema = infer_schema
789        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
791    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
792        """
793        Get the table for a column name.
794
795        Args:
796            column_name: The column name to find the table for.
797        Returns:
798            The table name if it can be found/inferred.
799        """
800        if self._unambiguous_columns is None:
801            self._unambiguous_columns = self._get_unambiguous_columns(
802                self._get_all_source_columns()
803            )
804
805        table_name = self._unambiguous_columns.get(column_name)
806
807        if not table_name and self._infer_schema:
808            sources_without_schema = tuple(
809                source
810                for source, columns in self._get_all_source_columns().items()
811                if not columns or "*" in columns
812            )
813            if len(sources_without_schema) == 1:
814                table_name = sources_without_schema[0]
815
816        if table_name not in self.scope.selected_sources:
817            return exp.to_identifier(table_name)
818
819        node, _ = self.scope.selected_sources.get(table_name)
820
821        if isinstance(node, exp.Query):
822            while node and node.alias != table_name:
823                node = node.parent
824
825        node_alias = node.args.get("alias")
826        if node_alias:
827            return exp.to_identifier(node_alias.this)
828
829        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
831    @property
832    def all_columns(self) -> t.Set[str]:
833        """All available columns of all sources in this scope"""
834        if self._all_columns is None:
835            self._all_columns = {
836                column for columns in self._get_all_source_columns().values() for column in columns
837            }
838        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
840    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
841        """Resolve the source columns for a given source `name`."""
842        cache_key = (name, only_visible)
843        if cache_key not in self._get_source_columns_cache:
844            if name not in self.scope.sources:
845                raise OptimizeError(f"Unknown table: {name}")
846
847            source = self.scope.sources[name]
848
849            if isinstance(source, exp.Table):
850                columns = self.schema.column_names(source, only_visible)
851            elif isinstance(source, Scope) and isinstance(
852                source.expression, (exp.Values, exp.Unnest)
853            ):
854                columns = source.expression.named_selects
855
856                # in bigquery, unnest structs are automatically scoped as tables, so you can
857                # directly select a struct field in a query.
858                # this handles the case where the unnest is statically defined.
859                if self.schema.dialect == "bigquery":
860                    if source.expression.is_type(exp.DataType.Type.STRUCT):
861                        for k in source.expression.type.expressions:  # type: ignore
862                            columns.append(k.name)
863            else:
864                columns = source.expression.named_selects
865
866            node, _ = self.scope.selected_sources.get(name) or (None, None)
867            if isinstance(node, Scope):
868                column_aliases = node.expression.alias_column_names
869            elif isinstance(node, exp.Expression):
870                column_aliases = node.alias_column_names
871            else:
872                column_aliases = []
873
874            if column_aliases:
875                # If the source's columns are aliased, their aliases shadow the corresponding column names.
876                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
877                columns = [
878                    alias or name
879                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
880                ]
881
882            self._get_source_columns_cache[cache_key] = columns
883
884        return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.