Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 11from sqlglot.optimizer.simplify import simplify_parens
 12from sqlglot.schema import Schema, ensure_schema
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot._typing import E
 16
 17
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Every scope of `expression` is processed in turn and rewritten in place; the object that
    is returned is the same `expression` that was passed in.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing. When left as None, inference
            is enabled exactly when the schema is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    # No explicit choice from the caller: infer only when there is no schema to consult.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        # A fresh resolver per scope, since column lookups depend on the scope's sources.
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # With an empty schema, alias references are expanded BEFORE columns are qualified...
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        # ...whereas with a populated schema they are expanded AFTER qualification.
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        # UDTF scopes get neither star expansion nor output aliasing.
        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        # Positional references are expanded last, once the projections have their aliases.
        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression
 77
 78
 79def validate_qualify_columns(expression: E) -> E:
 80    """Raise an `OptimizeError` if any columns aren't qualified"""
 81    all_unqualified_columns = []
 82    for scope in traverse_scope(expression):
 83        if isinstance(scope.expression, exp.Select):
 84            unqualified_columns = scope.unqualified_columns
 85
 86            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 87                column = scope.external_columns[0]
 88                for_table = f" for table: '{column.table}'" if column.table else ""
 89                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 90
 91            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 92                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 93                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 94                # this list here to ensure those in the former category will be excluded.
 95                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 96                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 97
 98            all_unqualified_columns.extend(unqualified_columns)
 99
100    if all_unqualified_columns:
101        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
102
103    return expression
104
105
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Iterate over the columns an UNPIVOT produces.

    The name column (the left-hand side of the `field` IN clause, when it is a column) comes
    first, followed by every column found under the UNPIVOT's value expressions.
    """
    in_clause = unpivot.args.get("field")
    has_name_column = isinstance(in_clause, exp.In) and isinstance(in_clause.this, exp.Column)
    leading = [in_clause.this] if has_name_column else []

    # Value columns are discovered lazily, one value expression at a time.
    trailing = itertools.chain.from_iterable(
        value.find_all(exp.Column) for value in unpivot.expressions
    )
    return itertools.chain(leading, trailing)
114
115
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Strip the column list from each derived table's alias, in place.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)

    Entries whose parent is a recursive WITH are left untouched.
    """
    for source in derived_tables:
        enclosing = source.parent
        inside_recursive_with = isinstance(enclosing, exp.With) and enclosing.recursive
        alias_node = None if inside_recursive_with else source.args.get("alias")

        if alias_node:
            alias_node.args.pop("columns", None)
128
129
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite each `JOIN ... USING (...)` in `scope` into a `JOIN ... ON ...`, in place.

    Every USING identifier becomes an equality between that column on an earlier source and
    the same column on the joined table. Afterwards, unqualified references to a USING column
    anywhere in the scope are replaced by a COALESCE over the tables that column was joined
    across.

    Args:
        scope: The scope whose joins and columns are rewritten.
        resolver: Used to look up the column names of each source.

    Returns:
        Mapping of USING column name to an insertion-ordered dict whose keys are the source
        names joined on that column (all values are None).

    Raises:
        OptimizeError: If a USING column is missing on either side of the join while the
            earlier sources' columns are known (non-empty, without "*") and the joined
            table's columns are non-empty.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that are not the target of a join; joined tables are appended as they are seen.
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # For each column name, the first already-seen source that provides it.
        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        # Fallback left-hand side when the column's owner is unknown: the latest seen source.
        # NOTE(review): this indexes `ordered[-1]` unconditionally, so it assumes at least one
        # non-joined source exists - confirm.
        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only complain when both sides expose a definite column list.
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        # Swap the USING clause for the ON condition built from all of its identifiers.
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables
196
197
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Replace references to projection aliases with the expressions they stand for, in place.

    Only SELECT scopes are processed. Projections are visited in order, so a projection can
    only pick up aliases defined by projections that precede it. The WHERE, GROUP BY, HAVING
    and QUALIFY clauses are processed afterwards and can see every alias.

    Args:
        scope: The scope to rewrite.
        resolver: Used to find the owning table of a column in HAVING / QUALIFY.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Alias name -> (aliased expression, 1-based position of the projection).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        # resolve_table: also try to qualify unqualified columns through the resolver.
        # literal_index: an alias of a literal is replaced by the projection's position
        # rather than by the literal itself.
        if not node:
            return

        for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # True when the aliased expression contains an aggregate and the reference itself
            # sits inside an aggregate whose nearest Window/Select ancestor is not a Window;
            # expanding the alias there would place one aggregate inside another.
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                # A real column (or an alias that must not be expanded): just qualify it.
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    # With resolve_table alone, a reference to a literal alias is left as is.
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    # Inline the aliased expression, then drop the parentheses if redundant.
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        # Register the alias only after its own projection has been processed.
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Nodes were replaced above. NOTE(review): presumably this drops the scope's cached
    # lookups so they are recomputed - confirm against Scope.clear_cache.
    scope.clear_cache()
254
255
def _expand_group_by(scope: Scope) -> None:
    """Expand positional references in the scope's GROUP BY clause, if it has one."""
    select = scope.expression
    group_by = select.args.get("group")

    if group_by:
        expanded = _expand_positional_references(scope, group_by.expressions)
        group_by.set("expressions", expanded)
        select.set("group", group_by)
264
265
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Normalize the scope's ORDER BY clause in place.

    Positional references are replaced by a column named after the referenced projection's
    alias, and unqualified columns inside aggregate functions are qualified through the
    resolver. When the query also has a GROUP BY, an ordering expression that equals a
    projection's expression is replaced by a reference to that projection's alias.

    Args:
        scope: The scope whose ORDER BY is rewritten.
        resolver: Used to find the table of unqualified columns inside aggregates.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    # The expanded list is fully built before the loop body runs, because
    # _expand_positional_references returns a list.
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # Projection expression -> column referencing that projection's output name.
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )
294
295
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer positions (as in `GROUP BY 1`) against the scope's projections.

    Args:
        scope: The scope whose projections the positions refer to.
        expressions: The expressions to inspect; non-integer ones are passed through as is.
        alias: If True, a position becomes a column named after the projection's alias.
            Otherwise it becomes a copy of the projected expression, unless that expression
            is a literal, in which case the position itself is kept.

    Returns:
        A new list with one entry per input expression, in the same order.
    """
    resolved: t.List[exp.Expression] = []

    for candidate in expressions:
        if not candidate.is_int:
            resolved.append(candidate)
            continue

        projection = _select_by_pos(scope, t.cast(exp.Literal, candidate))

        if alias:
            replacement = exp.column(projection.args["alias"].copy())
        else:
            target = projection.this
            replacement = candidate if isinstance(target, exp.Literal) else target.copy()

        resolved.append(replacement)

    return resolved
317
318
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection that the 1-based positional reference `node` points at.

    Args:
        scope: The scope whose projections are indexed.
        node: An integer literal holding the 1-based position.

    Returns:
        The aliased projection at that position.

    Raises:
        OptimizeError: If the position is smaller than 1 or larger than the number of
            projections.
    """
    selects = scope.expression.selects
    index = int(node.this) - 1

    # Check the range explicitly: a position of 0 (or below) would otherwise turn into a
    # negative index and silently pick a projection counted from the end.
    if index < 0 or index >= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[index].assert_is(exp.Alias)
324
325
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source

    Columns are updated in place. A column whose qualifier is not a known source is treated
    as a struct access and rewritten into a chain of Dot expressions.

    Raises:
        OptimizeError: If a column is qualified with a source of this scope whose known
            columns (non-empty, without "*") do not include it.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # A column outside the Pivot expression is qualified with the (first) pivot's
                # alias. Columns under the Pivot expression skip this branch and are resolved
                # against the scope's sources below instead.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    # Second pass: qualify still-unqualified columns inside each pivot, but only those whose
    # name is a column of one of the scope's sources.
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
375
376
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections

    Both a bare `*` and a table-qualified `tbl.*` are expanded. If the columns of any
    involved table are unknown (empty or containing "*"), the function returns early and the
    scope's projections are left exactly as they were.

    Args:
        scope: The scope whose projections are rewritten in place.
        resolver: Used to look up the visible columns of each source.
        using_column_tables: Mapping of USING column name to the tables it was joined
            across, as returned by `_expand_using`; such columns are emitted once as a
            COALESCE.
        pseudocolumns: Upper-case column names that must not be produced by an expansion.

    Raises:
        OptimizeError: If a star refers to a table that is not a source of this scope.
    """

    new_selections = []
    # Both mappings are keyed by id() of the table-name object held in `tables` below.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    pivot_output_columns = None
    pivot_exclude_columns = None

    # Only the first pivot of the scope is considered.
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                # The columns listed in the IN clause are consumed by the UNPIVOT.
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            # Bare `*`: expand across every selected source.
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star and not isinstance(expression, exp.Dot):
            # Qualified `tbl.*`: expand that one table.
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            # Not a star: keep the projection unchanged.
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_column_list

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Unknown column list: give up on the whole expansion and keep the original
            # projections (new_selections is discarded).
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    # Source columns not consumed by the pivot, then the pivot's own outputs.
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # A USING column is emitted once, coalesced over all tables it joins.
                    coalesced_columns.add(name)
                    # Rebinding `tables` here does not disturb the enclosing loop, which
                    # keeps iterating over the list it started with.
                    tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                else:
                    # Apply the alias from a REPLACE clause for this column, when present.
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
479
480
def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """
    Record the column names listed in a star's `except` arg for each of `tables`.

    Nothing is recorded when the star has no such arg. Entries in `except_columns` are keyed
    by `id(table)`, and every table is mapped to the same set object.
    """
    excluded = expression.args.get("except")

    if excluded:
        names = {item.name for item in excluded}
        except_columns.update((id(table), names) for table in tables)
493
494
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """
    Record the renames listed in a star's `replace` arg for each of `tables`.

    Each entry maps the name of the replaced expression to its alias. Nothing is recorded
    when the star has no such arg. Entries in `replace_columns` are keyed by `id(table)`,
    and every table is mapped to the same dict object.
    """
    replacements = expression.args.get("replace")

    if replacements:
        renames = {item.this.name: item.alias for item in replacements}
        replace_columns.update((id(table), renames) for table in tables)
507
508
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Give every projection of a query an alias, rewriting the projections in place.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built. If no
            scope can be built from the expression, nothing happens.

    Projections without an output name are called `_col_<i>`, where `i` is their 0-based
    position. Names from the scope's outer column list override the alias of the projection
    at the same position.
    """
    scope = scope_or_expression
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return

    aliased = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)

    for index, (projection, outer_name) in enumerate(pairs):
        if projection is None:
            # More outer column names than projections; nothing is left to alias.
            break

        fallback = f"_col_{index}"

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback)))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or fallback,
                copy=False,
            )

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased.append(projection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", aliased)
541
542
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    The dialect's `quote_identifier` is applied across `expression` through `transform`,
    called with `copy=False`; the result of that call is returned.
    """
    quoter = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quoter, identify=identify, copy=False)
548
549
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Column aliases are matched to projections by position. When a CTE lists fewer aliases
    than it has projections, the remaining projections are kept unchanged; aliases beyond
    the last projection are ignored.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = list(cte.alias_column_names)
        if not column_names:
            continue

        projections = cte.this.expressions
        new_expressions = []

        for _alias, projection in zip(column_names, projections):
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter input; carry over the projections that have no matching
        # alias so they are not dropped from the CTE.
        new_expressions.extend(projections[len(column_names) :])
        cte.this.set("expressions", new_expressions)

    return expression
580
581
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: The scope whose sources are used for resolution.
            schema: The schema consulted for the columns of table sources.
            infer_schema: Whether a column that matches no known source column may be
                attributed to the single source whose columns are unknown.
        """
        self.scope = scope
        self.schema = schema
        # The three caches below are filled lazily on first use.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # The column is not known to any source. If exactly one source has an unknown
            # column list, attribute the column to that source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Climb towards the root until a node carrying the expected alias is found.
            while node and node.alias != table_name:
                node = node.parent

        # The climb above ends with `node` set to None when no ancestor carries the alias;
        # in that case fall back to the plain source name instead of dereferencing None.
        node_alias = node.args.get("alias") if node else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: The name of a source in this resolver's scope.
            only_visible: Forwarded to the schema's column lookup for table sources.

        Returns:
            The source's column names, with any column aliases declared on the source
            taking the place of the names at the same positions.

        Raises:
            OptimizeError: If `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                column_alias or column_name
                for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Return (and cache) the columns of every selected and lateral source, by name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source provides it.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        # Every column name seen so far, ambiguous or not.
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A name seen before is ambiguous from now on; the default keeps the pop quiet
            # when an earlier source already removed it.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Every scope of `expression` is processed in turn and rewritten in place; the object that
    is returned is the same `expression` that was passed in.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing. When left as None, inference
            is enabled exactly when the schema is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    # No explicit choice from the caller: infer only when there is no schema to consult.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        # A fresh resolver per scope, since column lookups depend on the scope's sources.
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # With an empty schema, alias references are expanded BEFORE columns are qualified...
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        # ...whereas with a populated schema they are expanded AFTER qualification.
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        # UDTF scopes get neither star expansion nor output aliasing.
        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        # Positional references are expanded last, once the projections have their aliases.
        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
 80def validate_qualify_columns(expression: E) -> E:
 81    """Raise an `OptimizeError` if any columns aren't qualified"""
 82    all_unqualified_columns = []
 83    for scope in traverse_scope(expression):
 84        if isinstance(scope.expression, exp.Select):
 85            unqualified_columns = scope.unqualified_columns
 86
 87            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 88                column = scope.external_columns[0]
 89                for_table = f" for table: '{column.table}'" if column.table else ""
 90                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 91
 92            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 93                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 94                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 95                # this list here to ensure those in the former category will be excluded.
 96                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 97                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 98
 99            all_unqualified_columns.extend(unqualified_columns)
100
101    if all_unqualified_columns:
102        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
103
104    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Args:
        scope_or_expression: Either a ready-made scope or an expression. For an expression
            a scope is built first; if that does not yield a `Scope`, nothing is done.

    The selections are rewritten in place: un-aliased projections receive their output
    name (or a positional `_col_<i>` fallback) as alias, and names from the scope's
    outer column list override the alias of the projection at the same position.
    """
    scope = scope_or_expression
    if isinstance(scope, exp.Expression):
        scope = build_scope(scope)
        if not isinstance(scope, Scope):
            return

    paired = itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    qualified = []

    for i, (projection, outer_name) in enumerate(paired):
        if projection is None:
            # More outer column names than projections: nothing left to alias.
            break

        fallback_name = f"_col_{i}"

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback_name)))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or fallback_name,
                copy=False,
            )

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        qualified.append(projection)

    # Only SELECTs get their projection list replaced.
    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", qualified)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to transform; it is transformed with `copy=False`.
        dialect: The dialect whose `quote_identifier` method is applied to each node.
        identify: Forwarded to `quote_identifier` as a keyword argument.

    Returns:
        The result of the transformation.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.

    Notes:
        - The alias list is matched to the projections by position. Projections beyond the
          end of the alias list are kept as they are; surplus aliases are ignored.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = cte.alias_column_names
        if not column_names:
            continue

        new_expressions = []
        for index, projection in enumerate(cte.this.expressions):
            # Only the leading projections have a matching alias. The remaining ones must
            # be preserved, otherwise they would silently vanish from the CTE's query.
            if index < len(column_names):
                _alias = column_names[index]
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        cte.this.set("expressions", new_expressions)

    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
class Resolver:
    """
    Column-resolution helper bound to a single scope and schema.

    Implemented as a class so that the derived lookups can be computed lazily,
    cached on the instance, and shared between the functions that use it.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._infer_schema = infer_schema

        # Caches, filled in on first use.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Find the table that a column name belongs to.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            all_source_columns = self._get_all_source_columns()
            self._unambiguous_columns = self._get_unambiguous_columns(all_source_columns)

        table_name = self._unambiguous_columns.get(column_name)

        if self._infer_schema and not table_name:
            # Fall back to the one source whose columns are unknown (none listed, or a
            # star among them), provided there is exactly one such source.
            schemaless = [
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            ]
            if len(schemaless) == 1:
                table_name = schemaless[0]

        selected = self.scope.selected_sources
        if table_name not in selected:
            return exp.to_identifier(table_name)

        node, _ = selected.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up the tree until reaching the node that is aliased as the table name.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        return exp.to_identifier(node_alias.this if node_alias else table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """The union of the columns of every source in this scope (computed once)."""
        if self._all_columns is None:
            per_source = self._get_all_source_columns().values()
            self._all_columns = set(itertools.chain.from_iterable(per_source))
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source in this scope.
            only_visible: Forwarded to the schema lookup when the source is a table.

        Returns:
            The source's column names, with column aliases (if any) taking precedence.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        sources = self.scope.sources
        if name not in sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if not column_aliases:
            return columns

        # If the source's columns are aliased, their aliases shadow the corresponding
        # column names. This can be expensive if there are lots of columns, which is why
        # it is only done when column aliases exist.
        return [
            column_alias or column
            for column, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Map each selected and lateral source name to its columns (computed once)."""
        if self._source_columns is None:
            source_names = itertools.chain(
                self.scope.selected_sources, self.scope.lateral_sources
            )
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in source_names
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        (first_table, first_columns), *remaining = source_columns.items()

        if not remaining:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = dict.fromkeys(first_columns, first_table)
        seen = set(unambiguous_columns)

        for table, columns in remaining:
            unique = set(columns)
            ambiguous = seen & unique
            seen |= unique

            # A name already seen in an earlier source can no longer be attributed to one table.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique - ambiguous:
                unambiguous_columns[column] = table

        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
590    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
591        self.scope = scope
592        self.schema = schema
593        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
594        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
595        self._all_columns: t.Optional[t.Set[str]] = None
596        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
598    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
599        """
600        Get the table for a column name.
601
602        Args:
603            column_name: The column name to find the table for.
604        Returns:
605            The table name if it can be found/inferred.
606        """
607        if self._unambiguous_columns is None:
608            self._unambiguous_columns = self._get_unambiguous_columns(
609                self._get_all_source_columns()
610            )
611
612        table_name = self._unambiguous_columns.get(column_name)
613
614        if not table_name and self._infer_schema:
615            sources_without_schema = tuple(
616                source
617                for source, columns in self._get_all_source_columns().items()
618                if not columns or "*" in columns
619            )
620            if len(sources_without_schema) == 1:
621                table_name = sources_without_schema[0]
622
623        if table_name not in self.scope.selected_sources:
624            return exp.to_identifier(table_name)
625
626        node, _ = self.scope.selected_sources.get(table_name)
627
628        if isinstance(node, exp.Query):
629            while node and node.alias != table_name:
630                node = node.parent
631
632        node_alias = node.args.get("alias")
633        if node_alias:
634            return exp.to_identifier(node_alias.this)
635
636        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
638    @property
639    def all_columns(self) -> t.Set[str]:
640        """All available columns of all sources in this scope"""
641        if self._all_columns is None:
642            self._all_columns = {
643                column for columns in self._get_all_source_columns().values() for column in columns
644            }
645        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
647    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
648        """Resolve the source columns for a given source `name`."""
649        if name not in self.scope.sources:
650            raise OptimizeError(f"Unknown table: {name}")
651
652        source = self.scope.sources[name]
653
654        if isinstance(source, exp.Table):
655            columns = self.schema.column_names(source, only_visible)
656        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
657            columns = source.expression.alias_column_names
658        else:
659            columns = source.expression.named_selects
660
661        node, _ = self.scope.selected_sources.get(name) or (None, None)
662        if isinstance(node, Scope):
663            column_aliases = node.expression.alias_column_names
664        elif isinstance(node, exp.Expression):
665            column_aliases = node.alias_column_names
666        else:
667            column_aliases = []
668
669        if column_aliases:
670            # If the source's columns are aliased, their aliases shadow the corresponding column names.
671            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
672            return [
673                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
674            ]
675        return columns

Resolve the source columns for a given source name.