Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # When the caller doesn't decide, infer schemas only if none was provided at all
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects require alias refs to be expanded *before* qualification
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        # BigQuery-specific: type information is needed downstream (e.g. struct star expansion)
        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression
 96
 97
 98def validate_qualify_columns(expression: E) -> E:
 99    """Raise an `OptimizeError` if any columns aren't qualified"""
100    all_unqualified_columns = []
101    for scope in traverse_scope(expression):
102        if isinstance(scope.expression, exp.Select):
103            unqualified_columns = scope.unqualified_columns
104
105            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
106                column = scope.external_columns[0]
107                for_table = f" for table: '{column.table}'" if column.table else ""
108                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
109
110            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
111                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
112                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
113                # this list here to ensure those in the former category will be excluded.
114                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
115                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
116
117            all_unqualified_columns.extend(unqualified_columns)
118
119    if all_unqualified_columns:
120        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
121
122    return expression
123
124
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the UNPIVOT's name column (if present) followed by all of its value columns."""
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        yield field.this

    for expression in unpivot.expressions:
        yield from expression.find_all(exp.Column)
133
134
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived in derived_tables:
        parent = derived.parent
        # Recursive CTEs rely on their column list, so leave those intact
        if isinstance(parent, exp.With) and parent.recursive:
            continue

        alias_node = derived.args.get("alias")
        if alias_node:
            alias_node.args.pop("columns", None)
147
148
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite JOIN ... USING (...) clauses into equivalent explicit ON conditions.

    Returns:
        Mapping of each USING column name to an ordered "set" of source names
        (a dict whose values are all None — only key order matters), used later
        to COALESCE unqualified references to those columns.
    """
    # Ordered mapping of column name -> first source (in scope order) that provides it
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Non-join sources (e.g. the FROM table) come first, in selection order
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only raise when schemas are known and unambiguous ("*" means unknown columns)
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # Later joins may need to COALESCE the column over every earlier
                # source that provides it, to mirror USING semantics
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        # Replace the USING clause with the synthesized ON condition
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Rewrite unqualified references to USING columns as a COALESCE over all providers
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.func("coalesce", *coalesce_args)

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables
232
233
def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
    """
    Expand references to projection aliases in WHERE / GROUP BY / HAVING / QUALIFY,
    either by inlining the aliased expression or by qualifying the column with a table.

    Args:
        scope: Scope whose SELECT is mutated in place.
        resolver: Used to resolve unqualified column names to source tables.
        expand_only_groupby: If True, restrict expansion to the GROUP BY clause
            (dialect-specific early-expansion behavior).
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Maps a projection alias to (its expression, 1-based projection index)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            # Inlining an alias that contains an aggregate inside another aggregate would
            # create invalid nested aggregation (unless the reference sits in a window)
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        # Replace the alias with its positional reference (e.g. GROUP BY 2)
                        column.replace(exp.Literal.number(i))
                else:
                    # Wrap in parens to preserve precedence, then drop them if redundant
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # The AST was mutated, so any cached scope state is now stale
    scope.clear_cache()
308
309
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Expand positional references in GROUP BY (e.g. GROUP BY 1) into the projections they denote."""
    select = scope.expression
    group = select.args.get("group")
    if not group:
        return

    expanded = _expand_positional_references(scope, group.expressions, dialect)
    group.set("expressions", expanded)
    select.set("group", group)
318
319
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references in ORDER BY (e.g. ORDER BY 1) and qualify
    columns that appear inside aggregate functions in the ordering expressions.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        # Qualify unqualified columns inside aggregates used for ordering
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # With a GROUP BY present, ORDER BY should reference output names
        # rather than duplicating the projected expressions
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )
350
351
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """
    Replace integer literals that refer to projection positions (1-based) with
    either the projection's alias (``alias=True``) or a copy of its expression.

    Nodes that are not positional references are passed through unchanged.
    """
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None  # lazily computed, BigQuery only

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the positional reference when inlining would be invalid or ambiguous
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
396
397
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection at the 1-based position given by `node`.

    Args:
        scope: Scope whose SELECT list is indexed.
        node: Integer literal holding the 1-based projection position.

    Returns:
        The projection, asserted to be an `exp.Alias`.

    Raises:
        OptimizeError: If the position is out of range for the SELECT list.
    """
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError as e:
        # Chain the IndexError so the original failure stays visible in tracebacks (PEP 3134)
        raise OptimizeError(f"Unknown output column: {node.name}") from e
403
404
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # Only treat the "table" part as a struct when it can't be resolved to a real source
        # (neither in this scope, nor in the parent scope of a correlated subquery)
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
444
445
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            # A qualified column must exist in its source's schema ("*" means unknown columns)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    # Columns inside the PIVOT clause itself are qualified against the underlying sources
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
480
481
def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column

    Returns an empty list whenever the star can't be safely expanded (non-struct
    column, anonymous/ambiguous fields, or no matching nested field).
    """

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        # Rebuild the full path to the leaf field as a qualified column, aliased to the field name
        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
532
533
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections

    Handles plain `*`, qualified `t.*` (with EXCEPT / REPLACE / RENAME modifiers),
    BigQuery struct stars (`t.col.*`), PIVOT/UNPIVOT output columns, and columns
    coalesced by JOIN ... USING (via `using_column_tables`).
    """

    new_selections: t.List[exp.Expression] = []
    # All three are keyed by id(table) for the table the star modifier applies to
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()  # USING columns already emitted as a single COALESCE
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                # Columns consumed by the UNPIVOT's IN clause disappear from the output
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare `*`: expand over every selected source
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified `t.*`: expand over that single source
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                # `t.col.*`: BigQuery struct star
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            # Not a star — keep the projection as-is
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Unknown columns for some source: abandon expansion entirely
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING columns collapse into one COALESCE across their sources
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
655
656
def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record, per table id, the column names listed in a star's EXCEPT (...) modifier."""
    except_ = expression.args.get("except")
    if not except_:
        return

    names = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = names
669
670
def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record, per table id, the old-name -> new-name mapping of a star's RENAME (...) modifier."""
    rename = expression.args.get("rename")
    if not rename:
        return

    mapping = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = mapping
683
684
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    """Record, per table id, the alias -> replacement expression mapping of a star's REPLACE (...) modifier."""
    replace = expression.args.get("replace")
    if not replace:
        return

    mapping = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = mapping
697
698
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    # Pair each projection with the outer column name it must take (if any),
    # e.g. from a parent's `... AS t(a, b)` alias list
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            # More outer column names than projections — ignore the extras
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            # Outer column names take precedence over any existing alias
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
731
732
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
738
739
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if not cte.alias_column_names:
            continue

        pushed_down = []
        for column_name, projection in zip(cte.alias_column_names, cte.this.expressions):
            if isinstance(projection, exp.Alias):
                projection.set("alias", column_name)
            else:
                projection = alias(projection, alias=column_name)
            pushed_down.append(projection)

        cte.this.set("expressions", pushed_down)

    return expression
770
771
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        # Scope whose columns are being resolved.
        self.scope = scope
        # Schema used to look up table column names.
        self.schema = schema
        # Lazily-populated caches; each is computed on first access and reused.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Memoizes get_source_columns results, keyed by (source name, only_visible).
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If the column wasn't resolved unambiguously, but exactly one source has
            # no (or wildcard) column info, attribute the column to that source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up the AST to the ancestor that actually carries this alias.
            # NOTE(review): if no ancestor matches, `node` ends up None and the
            # `node.args` access below would raise — presumably unreachable; confirm.
            while node and node.alias != table_name:
                node = node.parent

        # Prefer the node's explicit alias over the resolved source name, if present.
        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`.

        Results are memoized per (name, only_visible) pair.

        Raises:
            OptimizeError: If `name` is not a known source in this scope.
        """
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Physical table: columns come from the schema.
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                # Derived table / CTE / subquery: columns are its named projections.
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                # zip_longest: a missing alias falls back to the column name, and extra
                # aliases (beyond the known columns) are still included.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Maps every selected/lateral source name to its resolved columns (cached).
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source provides it.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            # Columns already seen in a previous source are ambiguous: drop them.
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # Default to inferring the schema only when none was provided.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Expand alias references *before* qualification when the dialect forces it,
        # or when there is no schema to qualify against.
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        # With a schema available, alias references are expanded *after* qualification.
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        # BigQuery scopes additionally get type-annotated as part of qualification.
        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
  • allow_partial_qualification: Whether to allow partial qualification.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
 99def validate_qualify_columns(expression: E) -> E:
100    """Raise an `OptimizeError` if any columns aren't qualified"""
101    all_unqualified_columns = []
102    for scope in traverse_scope(expression):
103        if isinstance(scope.expression, exp.Select):
104            unqualified_columns = scope.unqualified_columns
105
106            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
107                column = scope.external_columns[0]
108                for_table = f" for table: '{column.table}'" if column.table else ""
109                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
110
111            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
112                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
113                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
114                # this list here to ensure those in the former category will be excluded.
115                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
116                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
117
118            all_unqualified_columns.extend(unqualified_columns)
119
120    if all_unqualified_columns:
121        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
122
123    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if not isinstance(scope_or_expression, exp.Expression):
        scope = scope_or_expression
    else:
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return

    selections = []
    zipped = itertools.zip_longest(scope.expression.selects, scope.outer_columns)

    for i, (selection, outer_name) in enumerate(zipped):
        if selection is None:
            # More outer column names than selections; nothing left to alias.
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                placeholder = exp.to_identifier(f"_col_{i}")
                selection.set("alias", exp.TableAlias(this=placeholder))
        elif not (isinstance(selection, exp.Alias) or selection.is_star):
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )

        # An outer column name (e.g. a CTE's alias columns) overrides any alias.
        if outer_name:
            selection.set("alias", exp.to_identifier(outer_name))

        selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
734def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
735    """Makes sure all identifiers that need to be quoted are quoted."""
736    return expression.transform(
737        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
738    )  # type: ignore

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
741def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
742    """
743    Pushes down the CTE alias columns into the projection,
744
745    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
746
747    Example:
748        >>> import sqlglot
749        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
750        >>> pushdown_cte_alias_columns(expression).sql()
751        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
752
753    Args:
754        expression: Expression to pushdown.
755
756    Returns:
757        The expression with the CTE aliases pushed down into the projection.
758    """
759    for cte in expression.find_all(exp.CTE):
760        if cte.alias_column_names:
761            new_expressions = []
762            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
763                if isinstance(projection, exp.Alias):
764                    projection.set("alias", _alias)
765                else:
766                    projection = alias(projection, alias=_alias)
767                new_expressions.append(projection)
768            cte.this.set("expressions", new_expressions)
769
770    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
773class Resolver:
774    """
775    Helper for resolving columns.
776
777    This is a class so we can lazily load some things and easily share them across functions.
778    """
779
780    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
781        self.scope = scope
782        self.schema = schema
783        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
784        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
785        self._all_columns: t.Optional[t.Set[str]] = None
786        self._infer_schema = infer_schema
787        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
788
789    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
790        """
791        Get the table for a column name.
792
793        Args:
794            column_name: The column name to find the table for.
795        Returns:
796            The table name if it can be found/inferred.
797        """
798        if self._unambiguous_columns is None:
799            self._unambiguous_columns = self._get_unambiguous_columns(
800                self._get_all_source_columns()
801            )
802
803        table_name = self._unambiguous_columns.get(column_name)
804
805        if not table_name and self._infer_schema:
806            sources_without_schema = tuple(
807                source
808                for source, columns in self._get_all_source_columns().items()
809                if not columns or "*" in columns
810            )
811            if len(sources_without_schema) == 1:
812                table_name = sources_without_schema[0]
813
814        if table_name not in self.scope.selected_sources:
815            return exp.to_identifier(table_name)
816
817        node, _ = self.scope.selected_sources.get(table_name)
818
819        if isinstance(node, exp.Query):
820            while node and node.alias != table_name:
821                node = node.parent
822
823        node_alias = node.args.get("alias")
824        if node_alias:
825            return exp.to_identifier(node_alias.this)
826
827        return exp.to_identifier(table_name)
828
829    @property
830    def all_columns(self) -> t.Set[str]:
831        """All available columns of all sources in this scope"""
832        if self._all_columns is None:
833            self._all_columns = {
834                column for columns in self._get_all_source_columns().values() for column in columns
835            }
836        return self._all_columns
837
838    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
839        """Resolve the source columns for a given source `name`."""
840        cache_key = (name, only_visible)
841        if cache_key not in self._get_source_columns_cache:
842            if name not in self.scope.sources:
843                raise OptimizeError(f"Unknown table: {name}")
844
845            source = self.scope.sources[name]
846
847            if isinstance(source, exp.Table):
848                columns = self.schema.column_names(source, only_visible)
849            elif isinstance(source, Scope) and isinstance(
850                source.expression, (exp.Values, exp.Unnest)
851            ):
852                columns = source.expression.named_selects
853
854                # in bigquery, unnest structs are automatically scoped as tables, so you can
855                # directly select a struct field in a query.
856                # this handles the case where the unnest is statically defined.
857                if self.schema.dialect == "bigquery":
858                    if source.expression.is_type(exp.DataType.Type.STRUCT):
859                        for k in source.expression.type.expressions:  # type: ignore
860                            columns.append(k.name)
861            else:
862                columns = source.expression.named_selects
863
864            node, _ = self.scope.selected_sources.get(name) or (None, None)
865            if isinstance(node, Scope):
866                column_aliases = node.expression.alias_column_names
867            elif isinstance(node, exp.Expression):
868                column_aliases = node.alias_column_names
869            else:
870                column_aliases = []
871
872            if column_aliases:
873                # If the source's columns are aliased, their aliases shadow the corresponding column names.
874                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
875                columns = [
876                    alias or name
877                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
878                ]
879
880            self._get_source_columns_cache[cache_key] = columns
881
882        return self._get_source_columns_cache[cache_key]
883
884    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
885        if self._source_columns is None:
886            self._source_columns = {
887                source_name: self.get_source_columns(source_name)
888                for source_name, source in itertools.chain(
889                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
890                )
891            }
892        return self._source_columns
893
894    def _get_unambiguous_columns(
895        self, source_columns: t.Dict[str, t.Sequence[str]]
896    ) -> t.Mapping[str, str]:
897        """
898        Find all the unambiguous columns in sources.
899
900        Args:
901            source_columns: Mapping of names to source columns.
902
903        Returns:
904            Mapping of column name to source name.
905        """
906        if not source_columns:
907            return {}
908
909        source_columns_pairs = list(source_columns.items())
910
911        first_table, first_columns = source_columns_pairs[0]
912
913        if len(source_columns_pairs) == 1:
914            # Performance optimization - avoid copying first_columns if there is only one table.
915            return SingleValuedMapping(first_columns, first_table)
916
917        unambiguous_columns = {col: first_table for col in first_columns}
918        all_columns = set(unambiguous_columns)
919
920        for table, columns in source_columns_pairs[1:]:
921            unique = set(columns)
922            ambiguous = all_columns.intersection(unique)
923            all_columns.update(columns)
924
925            for column in ambiguous:
926                unambiguous_columns.pop(column, None)
927            for column in unique.difference(ambiguous):
928                unambiguous_columns[column] = table
929
930        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
780    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
781        self.scope = scope
782        self.schema = schema
783        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
784        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
785        self._all_columns: t.Optional[t.Set[str]] = None
786        self._infer_schema = infer_schema
787        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
789    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
790        """
791        Get the table for a column name.
792
793        Args:
794            column_name: The column name to find the table for.
795        Returns:
796            The table name if it can be found/inferred.
797        """
798        if self._unambiguous_columns is None:
799            self._unambiguous_columns = self._get_unambiguous_columns(
800                self._get_all_source_columns()
801            )
802
803        table_name = self._unambiguous_columns.get(column_name)
804
805        if not table_name and self._infer_schema:
806            sources_without_schema = tuple(
807                source
808                for source, columns in self._get_all_source_columns().items()
809                if not columns or "*" in columns
810            )
811            if len(sources_without_schema) == 1:
812                table_name = sources_without_schema[0]
813
814        if table_name not in self.scope.selected_sources:
815            return exp.to_identifier(table_name)
816
817        node, _ = self.scope.selected_sources.get(table_name)
818
819        if isinstance(node, exp.Query):
820            while node and node.alias != table_name:
821                node = node.parent
822
823        node_alias = node.args.get("alias")
824        if node_alias:
825            return exp.to_identifier(node_alias.this)
826
827        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
829    @property
830    def all_columns(self) -> t.Set[str]:
831        """All available columns of all sources in this scope"""
832        if self._all_columns is None:
833            self._all_columns = {
834                column for columns in self._get_all_source_columns().values() for column in columns
835            }
836        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
838    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
839        """Resolve the source columns for a given source `name`."""
840        cache_key = (name, only_visible)
841        if cache_key not in self._get_source_columns_cache:
842            if name not in self.scope.sources:
843                raise OptimizeError(f"Unknown table: {name}")
844
845            source = self.scope.sources[name]
846
847            if isinstance(source, exp.Table):
848                columns = self.schema.column_names(source, only_visible)
849            elif isinstance(source, Scope) and isinstance(
850                source.expression, (exp.Values, exp.Unnest)
851            ):
852                columns = source.expression.named_selects
853
854                # in bigquery, unnest structs are automatically scoped as tables, so you can
855                # directly select a struct field in a query.
856                # this handles the case where the unnest is statically defined.
857                if self.schema.dialect == "bigquery":
858                    if source.expression.is_type(exp.DataType.Type.STRUCT):
859                        for k in source.expression.type.expressions:  # type: ignore
860                            columns.append(k.name)
861            else:
862                columns = source.expression.named_selects
863
864            node, _ = self.scope.selected_sources.get(name) or (None, None)
865            if isinstance(node, Scope):
866                column_aliases = node.expression.alias_column_names
867            elif isinstance(node, exp.Expression):
868                column_aliases = node.alias_column_names
869            else:
870                column_aliases = []
871
872            if column_aliases:
873                # If the source's columns are aliased, their aliases shadow the corresponding column names.
874                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
875                columns = [
876                    alias or name
877                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
878                ]
879
880            self._get_source_columns_cache[cache_key] = columns
881
882        return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.