Edit on GitHub

sqlglot.optimizer.resolver

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.scope import Scope
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.schema import Schema
 14    from collections.abc import Sequence, Mapping
 15
 16
 17class Resolver:
 18    """
 19    Helper for resolving columns.
 20
 21    This is a class so we can lazily load some things and easily share them across functions.
 22    """
 23
 24    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
 25        self.scope = scope
 26        self.schema = schema
 27        self.dialect = schema.dialect or Dialect()
 28        self._source_columns: dict[str, Sequence[str]] | None = None
 29        self._unambiguous_columns: Mapping[str, str] | None = None
 30        self._all_columns: set[str] | None = None
 31        self._infer_schema = infer_schema
 32        self._get_source_columns_cache: dict[tuple[str, bool], Sequence[str]] = {}
 33
 34    def get_table(self, column: str | exp.Column) -> exp.Identifier | None:
 35        """
 36        Get the table for a column name.
 37
 38        Args:
 39            column: The column expression (or column name) to find the table for.
 40        Returns:
 41            The table name if it can be found/inferred.
 42        """
 43        column_name = column if isinstance(column, str) else column.name
 44
 45        table_name = self._get_table_name_from_sources(column_name)
 46
 47        if not table_name and isinstance(column, exp.Column):
 48            # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
 49            # attempt to disambiguate the column based on other characteristics e.g if this column is in a join condition,
 50            # we may be able to disambiguate based on the source order.
 51            if join_context := self._get_column_join_context(column):
 52                # In this case, the return value will be the join that _may_ be able to disambiguate the column
 53                # and we can use the source columns available at that join to get the table name
 54                # catch OptimizeError if column is still ambiguous and try to resolve with schema inference below
 55                try:
 56                    table_name = self._get_table_name_from_sources(
 57                        column_name, self._get_available_source_columns(join_context)
 58                    )
 59                except OptimizeError:
 60                    pass
 61
 62        if not table_name and self._infer_schema:
 63            sources_without_schema = tuple(
 64                source
 65                for source, columns in self._get_all_source_columns().items()
 66                if not columns or "*" in columns
 67            )
 68            if len(sources_without_schema) == 1:
 69                table_name = sources_without_schema[0]
 70
 71        if table_name not in self.scope.selected_sources:
 72            return exp.to_identifier(table_name)
 73
 74        node: exp.Expr = self.scope.selected_sources[table_name][0]
 75
 76        if isinstance(node, exp.Query):
 77            while node and node.alias != table_name and node.parent:
 78                node = node.parent
 79
 80        node_alias = node.args.get("alias")
 81        if node_alias:
 82            return exp.to_identifier(node_alias.this)
 83
 84        return exp.to_identifier(table_name)
 85
 86    @property
 87    def all_columns(self) -> set[str]:
 88        """All available columns of all sources in this scope"""
 89        if self._all_columns is None:
 90            self._all_columns = {
 91                column for columns in self._get_all_source_columns().values() for column in columns
 92            }
 93        return self._all_columns
 94
 95    def get_source_columns_from_set_op(self, expression: exp.Expr) -> list[str]:
 96        if isinstance(expression, exp.Select):
 97            return expression.named_selects
 98        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
 99            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
100            return self.get_source_columns_from_set_op(expression.this)
101        if not isinstance(expression, exp.SetOperation):
102            raise OptimizeError(f"Unknown set operation: {expression}")
103
104        set_op = expression
105
106        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
107        on_column_list = set_op.args.get("on")
108
109        if on_column_list:
110            # The resulting columns are the columns in the ON clause:
111            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
112            columns = [col.name for col in on_column_list]
113        elif set_op.side or set_op.kind:
114            side = set_op.side
115            kind = set_op.kind
116
117            # Visit the children UNIONs (if any) in a post-order traversal
118            left = self.get_source_columns_from_set_op(set_op.left)
119            right = self.get_source_columns_from_set_op(set_op.right)
120
121            # We use dict.fromkeys to deduplicate keys and maintain insertion order
122            if side == "LEFT":
123                columns = left
124            elif side == "FULL":
125                columns = list(dict.fromkeys(left + right))
126            elif kind == "INNER":
127                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
128        else:
129            columns = set_op.named_selects
130
131        return columns
132
133    def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
134        """Resolve the source columns for a given source `name`."""
135        cache_key = (name, only_visible)
136        if cache_key not in self._get_source_columns_cache:
137            if name not in self.scope.sources:
138                raise OptimizeError(f"Unknown table: {name}")
139
140            source = self.scope.sources[name]
141
142            if isinstance(source, exp.Table):
143                columns = self.schema.column_names(source, only_visible)
144            elif isinstance(source, Scope) and isinstance(
145                source.expression, (exp.Values, exp.Unnest)
146            ):
147                columns = source.expression.named_selects
148
149                # in bigquery, unnest structs are automatically scoped as tables, so you can
150                # directly select a struct field in a query.
151                # this handles the case where the unnest is statically defined.
152                if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
153                    unnest = source.expression
154
155                    # if type is not annotated yet, try to get it from the schema
156                    if not unnest.type or unnest.type.is_type(exp.DType.UNKNOWN):
157                        unnest_expr = seq_get(unnest.expressions, 0)
158                        if isinstance(unnest_expr, exp.Column) and self.scope.parent:
159                            col_type = self._get_unnest_column_type(unnest_expr)
160                            # extract element type if it's an ARRAY
161                            if col_type and col_type.is_type(exp.DType.ARRAY):
162                                element_types = col_type.expressions
163                                if element_types:
164                                    unnest.type = element_types[0].copy()
165                            else:
166                                if col_type:
167                                    unnest.type = col_type.copy()
168                    # check if the result type is a STRUCT - extract struct field names
169                    if unnest.is_type(exp.DType.STRUCT):
170                        for k in unnest.type.expressions:  # type: ignore
171                            columns.append(k.name)
172            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
173                columns = self.get_source_columns_from_set_op(source.expression)
174            else:
175                selectable = source.expression.assert_is(exp.Selectable)
176                select = seq_get(selectable.selects, 0)
177
178                if isinstance(select, exp.QueryTransform):
179                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
180                    schema = select.args.get("schema")
181                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
182                else:
183                    columns = selectable.named_selects
184
185            node, _ = self.scope.selected_sources.get(name) or (None, None)
186            if isinstance(node, Scope):
187                column_aliases = node.expression.alias_column_names
188            elif isinstance(node, exp.Expr):
189                column_aliases = node.alias_column_names
190            else:
191                column_aliases = []
192
193            if column_aliases:
194                # If the source's columns are aliased, their aliases shadow the corresponding column names.
195                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
196                columns = [
197                    alias or name
198                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
199                ]
200
201            self._get_source_columns_cache[cache_key] = columns
202
203        return self._get_source_columns_cache[cache_key]
204
205    def _get_all_source_columns(self) -> dict[str, Sequence[str]]:
206        if self._source_columns is None:
207            self._source_columns = {
208                source_name: self.get_source_columns(source_name)
209                for source_name, source in itertools.chain(
210                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
211                )
212            }
213        return self._source_columns
214
215    def _get_table_name_from_sources(
216        self, column_name: str, source_columns: dict[str, Sequence[str]] | None = None
217    ) -> str | None:
218        if not source_columns:
219            # If not supplied, get all sources to calculate unambiguous columns
220            if self._unambiguous_columns is None:
221                self._unambiguous_columns = self._get_unambiguous_columns(
222                    self._get_all_source_columns()
223                )
224
225            unambiguous_columns = self._unambiguous_columns
226        else:
227            unambiguous_columns = self._get_unambiguous_columns(source_columns)
228
229        return unambiguous_columns.get(column_name)
230
231    def _get_column_join_context(self, column: exp.Column) -> exp.Join | None:
232        """
233        Check if a column participating in a join can be qualified based on the source order.
234        """
235        args = self.scope.expression.args
236        joins = args.get("joins")
237
238        if not joins or args.get("laterals") or args.get("pivots"):
239            # Feature gap: We currently don't try to disambiguate columns if other sources
240            # (e.g laterals, pivots) exist alongside joins
241            return None
242
243        join_ancestor = column.find_ancestor(exp.Join, exp.Select)
244
245        if (
246            isinstance(join_ancestor, exp.Join)
247            and join_ancestor.alias_or_name in self.scope.selected_sources
248        ):
249            # Ensure that the found ancestor is a join that contains an actual source,
250            # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
251            return join_ancestor
252
253        return None
254
255    def _get_available_source_columns(self, join_ancestor: exp.Join) -> dict[str, Sequence[str]]:
256        """
257        Get the source columns that are available at the point where a column is referenced.
258
259        For columns in JOIN conditions, this only includes tables that have been joined
260        up to that point. Example:
261
262        ```
263        SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ...
264        ```                                                        ^
265                                                                   |
266                                +----------------------------------+
267                                |
268                                ⌄
269        The unqualified column `c` is not ambiguous if no other sources up until that
270        join i.e t_1, ..., t_n, contain a column named `c`.
271
272        """
273        args = self.scope.expression.args
274
275        # Collect tables in order: FROM clause tables + joined tables up to current join
276        from_name = args["from_"].alias_or_name
277        available_sources = {from_name: self.get_source_columns(from_name)}
278
279        for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
280            available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)
281
282        return available_sources
283
284    def _get_unambiguous_columns(
285        self, source_columns: dict[str, Sequence[str]]
286    ) -> Mapping[str, str]:
287        """
288        Find all the unambiguous columns in sources.
289
290        Args:
291            source_columns: Mapping of names to source columns.
292
293        Returns:
294            Mapping of column name to source name.
295        """
296        if not source_columns:
297            return {}
298
299        source_columns_pairs = list(source_columns.items())
300
301        first_table, first_columns = source_columns_pairs[0]
302
303        if len(source_columns_pairs) == 1:
304            # Performance optimization - avoid copying first_columns if there is only one table.
305            return SingleValuedMapping(first_columns, first_table)
306
307        # For BigQuery UNNEST_COLUMN_ONLY, build a mapping of original UNNEST aliases
308        # from alias.columns[0] to their source names. This is used to resolve shadowing
309        # where an UNNEST alias shadows a column name from another table.
310        unnest_original_aliases: dict[str, str] = {}
311        if self.dialect.UNNEST_COLUMN_ONLY:
312            unnest_original_aliases = {
313                alias_arg.columns[0].name: source_name
314                for source_name, source in self.scope.sources.items()
315                if (
316                    isinstance(source.expression, exp.Unnest)
317                    and (alias_arg := source.expression.args.get("alias"))
318                    and alias_arg.columns
319                )
320            }
321
322        unambiguous_columns = {col: first_table for col in first_columns}
323        all_columns = set(unambiguous_columns)
324
325        for table, columns in source_columns_pairs[1:]:
326            unique = set(columns)
327            ambiguous = all_columns.intersection(unique)
328            all_columns.update(columns)
329
330            for column in ambiguous:
331                if column in unnest_original_aliases:
332                    unambiguous_columns[column] = unnest_original_aliases[column]
333                    continue
334
335                unambiguous_columns.pop(column, None)
336            for column in unique.difference(ambiguous):
337                unambiguous_columns[column] = table
338
339        return unambiguous_columns
340
341    def _get_unnest_column_type(self, column: exp.Column) -> exp.DataType | None:
342        """
343        Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.
344
345        Args:
346            column: The column expression being unnested.
347
348        Returns:
349            The DataType of the column, or None if not found.
350        """
351        scope = self.scope.parent
352        assert scope
353
354        # if column is qualified, use that table, otherwise disambiguate using the resolver
355        if column.table:
356            table_name = column.table
357        else:
358            # use the parent scope's resolver to disambiguate the column
359            parent_resolver = Resolver(scope, self.schema, self._infer_schema)
360            table_identifier = parent_resolver.get_table(column)
361            if not table_identifier:
362                return None
363            table_name = table_identifier.name
364
365        source = scope.sources.get(table_name)
366        return self._get_column_type_from_scope(source, column) if source else None
367
368    def _get_column_type_from_scope(
369        self, source: Scope | exp.Table, column: exp.Column
370    ) -> exp.DataType | None:
371        """
372        Get a column's type by tracing through scopes/tables to find the base table.
373
374        Args:
375            source: The source to search - can be a Scope (to iterate its sources) or a Table.
376            column: The column to find the type for.
377
378        Returns:
379            The DataType of the column, or None if not found.
380        """
381        if isinstance(source, exp.Table):
382            # base table - get the column type from schema
383            col_type: exp.DataType | None = self.schema.get_column_type(source, column)
384            if col_type and not col_type.is_type(exp.DType.UNKNOWN):
385                return col_type
386        elif isinstance(source, Scope):
387            # iterate over all sources in the scope
388            for source_name, nested_source in source.sources.items():
389                col_type = self._get_column_type_from_scope(nested_source, column)
390                if col_type and not col_type.is_type(exp.DType.UNKNOWN):
391                    return col_type
392
393        return None
class Resolver:
 18class Resolver:
 19    """
 20    Helper for resolving columns.
 21
 22    This is a class so we can lazily load some things and easily share them across functions.
 23    """
 24
 25    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
 26        self.scope = scope
 27        self.schema = schema
 28        self.dialect = schema.dialect or Dialect()
 29        self._source_columns: dict[str, Sequence[str]] | None = None
 30        self._unambiguous_columns: Mapping[str, str] | None = None
 31        self._all_columns: set[str] | None = None
 32        self._infer_schema = infer_schema
 33        self._get_source_columns_cache: dict[tuple[str, bool], Sequence[str]] = {}
 34
 35    def get_table(self, column: str | exp.Column) -> exp.Identifier | None:
 36        """
 37        Get the table for a column name.
 38
 39        Args:
 40            column: The column expression (or column name) to find the table for.
 41        Returns:
 42            The table name if it can be found/inferred.
 43        """
 44        column_name = column if isinstance(column, str) else column.name
 45
 46        table_name = self._get_table_name_from_sources(column_name)
 47
 48        if not table_name and isinstance(column, exp.Column):
 49            # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
 50            # attempt to disambiguate the column based on other characteristics e.g if this column is in a join condition,
 51            # we may be able to disambiguate based on the source order.
 52            if join_context := self._get_column_join_context(column):
 53                # In this case, the return value will be the join that _may_ be able to disambiguate the column
 54                # and we can use the source columns available at that join to get the table name
 55                # catch OptimizeError if column is still ambiguous and try to resolve with schema inference below
 56                try:
 57                    table_name = self._get_table_name_from_sources(
 58                        column_name, self._get_available_source_columns(join_context)
 59                    )
 60                except OptimizeError:
 61                    pass
 62
 63        if not table_name and self._infer_schema:
 64            sources_without_schema = tuple(
 65                source
 66                for source, columns in self._get_all_source_columns().items()
 67                if not columns or "*" in columns
 68            )
 69            if len(sources_without_schema) == 1:
 70                table_name = sources_without_schema[0]
 71
 72        if table_name not in self.scope.selected_sources:
 73            return exp.to_identifier(table_name)
 74
 75        node: exp.Expr = self.scope.selected_sources[table_name][0]
 76
 77        if isinstance(node, exp.Query):
 78            while node and node.alias != table_name and node.parent:
 79                node = node.parent
 80
 81        node_alias = node.args.get("alias")
 82        if node_alias:
 83            return exp.to_identifier(node_alias.this)
 84
 85        return exp.to_identifier(table_name)
 86
 87    @property
 88    def all_columns(self) -> set[str]:
 89        """All available columns of all sources in this scope"""
 90        if self._all_columns is None:
 91            self._all_columns = {
 92                column for columns in self._get_all_source_columns().values() for column in columns
 93            }
 94        return self._all_columns
 95
 96    def get_source_columns_from_set_op(self, expression: exp.Expr) -> list[str]:
 97        if isinstance(expression, exp.Select):
 98            return expression.named_selects
 99        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
100            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
101            return self.get_source_columns_from_set_op(expression.this)
102        if not isinstance(expression, exp.SetOperation):
103            raise OptimizeError(f"Unknown set operation: {expression}")
104
105        set_op = expression
106
107        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
108        on_column_list = set_op.args.get("on")
109
110        if on_column_list:
111            # The resulting columns are the columns in the ON clause:
112            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
113            columns = [col.name for col in on_column_list]
114        elif set_op.side or set_op.kind:
115            side = set_op.side
116            kind = set_op.kind
117
118            # Visit the children UNIONs (if any) in a post-order traversal
119            left = self.get_source_columns_from_set_op(set_op.left)
120            right = self.get_source_columns_from_set_op(set_op.right)
121
122            # We use dict.fromkeys to deduplicate keys and maintain insertion order
123            if side == "LEFT":
124                columns = left
125            elif side == "FULL":
126                columns = list(dict.fromkeys(left + right))
127            elif kind == "INNER":
128                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
129        else:
130            columns = set_op.named_selects
131
132        return columns
133
134    def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
135        """Resolve the source columns for a given source `name`."""
136        cache_key = (name, only_visible)
137        if cache_key not in self._get_source_columns_cache:
138            if name not in self.scope.sources:
139                raise OptimizeError(f"Unknown table: {name}")
140
141            source = self.scope.sources[name]
142
143            if isinstance(source, exp.Table):
144                columns = self.schema.column_names(source, only_visible)
145            elif isinstance(source, Scope) and isinstance(
146                source.expression, (exp.Values, exp.Unnest)
147            ):
148                columns = source.expression.named_selects
149
150                # in bigquery, unnest structs are automatically scoped as tables, so you can
151                # directly select a struct field in a query.
152                # this handles the case where the unnest is statically defined.
153                if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
154                    unnest = source.expression
155
156                    # if type is not annotated yet, try to get it from the schema
157                    if not unnest.type or unnest.type.is_type(exp.DType.UNKNOWN):
158                        unnest_expr = seq_get(unnest.expressions, 0)
159                        if isinstance(unnest_expr, exp.Column) and self.scope.parent:
160                            col_type = self._get_unnest_column_type(unnest_expr)
161                            # extract element type if it's an ARRAY
162                            if col_type and col_type.is_type(exp.DType.ARRAY):
163                                element_types = col_type.expressions
164                                if element_types:
165                                    unnest.type = element_types[0].copy()
166                            else:
167                                if col_type:
168                                    unnest.type = col_type.copy()
169                    # check if the result type is a STRUCT - extract struct field names
170                    if unnest.is_type(exp.DType.STRUCT):
171                        for k in unnest.type.expressions:  # type: ignore
172                            columns.append(k.name)
173            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
174                columns = self.get_source_columns_from_set_op(source.expression)
175            else:
176                selectable = source.expression.assert_is(exp.Selectable)
177                select = seq_get(selectable.selects, 0)
178
179                if isinstance(select, exp.QueryTransform):
180                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
181                    schema = select.args.get("schema")
182                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
183                else:
184                    columns = selectable.named_selects
185
186            node, _ = self.scope.selected_sources.get(name) or (None, None)
187            if isinstance(node, Scope):
188                column_aliases = node.expression.alias_column_names
189            elif isinstance(node, exp.Expr):
190                column_aliases = node.alias_column_names
191            else:
192                column_aliases = []
193
194            if column_aliases:
195                # If the source's columns are aliased, their aliases shadow the corresponding column names.
196                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
197                columns = [
198                    alias or name
199                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
200                ]
201
202            self._get_source_columns_cache[cache_key] = columns
203
204        return self._get_source_columns_cache[cache_key]
205
206    def _get_all_source_columns(self) -> dict[str, Sequence[str]]:
207        if self._source_columns is None:
208            self._source_columns = {
209                source_name: self.get_source_columns(source_name)
210                for source_name, source in itertools.chain(
211                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
212                )
213            }
214        return self._source_columns
215
216    def _get_table_name_from_sources(
217        self, column_name: str, source_columns: dict[str, Sequence[str]] | None = None
218    ) -> str | None:
219        if not source_columns:
220            # If not supplied, get all sources to calculate unambiguous columns
221            if self._unambiguous_columns is None:
222                self._unambiguous_columns = self._get_unambiguous_columns(
223                    self._get_all_source_columns()
224                )
225
226            unambiguous_columns = self._unambiguous_columns
227        else:
228            unambiguous_columns = self._get_unambiguous_columns(source_columns)
229
230        return unambiguous_columns.get(column_name)
231
232    def _get_column_join_context(self, column: exp.Column) -> exp.Join | None:
233        """
234        Check if a column participating in a join can be qualified based on the source order.
235        """
236        args = self.scope.expression.args
237        joins = args.get("joins")
238
239        if not joins or args.get("laterals") or args.get("pivots"):
240            # Feature gap: We currently don't try to disambiguate columns if other sources
241            # (e.g laterals, pivots) exist alongside joins
242            return None
243
244        join_ancestor = column.find_ancestor(exp.Join, exp.Select)
245
246        if (
247            isinstance(join_ancestor, exp.Join)
248            and join_ancestor.alias_or_name in self.scope.selected_sources
249        ):
250            # Ensure that the found ancestor is a join that contains an actual source,
251            # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
252            return join_ancestor
253
254        return None
255
256    def _get_available_source_columns(self, join_ancestor: exp.Join) -> dict[str, Sequence[str]]:
257        """
258        Get the source columns that are available at the point where a column is referenced.
259
260        For columns in JOIN conditions, this only includes tables that have been joined
261        up to that point. Example:
262
263        ```
264        SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ...
265        ```                                                        ^
266                                                                   |
267                                +----------------------------------+
268                                |
269                                ⌄
270        The unqualified column `c` is not ambiguous if no other sources up until that
271        join i.e t_1, ..., t_n, contain a column named `c`.
272
273        """
274        args = self.scope.expression.args
275
276        # Collect tables in order: FROM clause tables + joined tables up to current join
277        from_name = args["from_"].alias_or_name
278        available_sources = {from_name: self.get_source_columns(from_name)}
279
280        for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
281            available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)
282
283        return available_sources
284
285    def _get_unambiguous_columns(
286        self, source_columns: dict[str, Sequence[str]]
287    ) -> Mapping[str, str]:
288        """
289        Find all the unambiguous columns in sources.
290
291        Args:
292            source_columns: Mapping of names to source columns.
293
294        Returns:
295            Mapping of column name to source name.
296        """
297        if not source_columns:
298            return {}
299
300        source_columns_pairs = list(source_columns.items())
301
302        first_table, first_columns = source_columns_pairs[0]
303
304        if len(source_columns_pairs) == 1:
305            # Performance optimization - avoid copying first_columns if there is only one table.
306            return SingleValuedMapping(first_columns, first_table)
307
308        # For BigQuery UNNEST_COLUMN_ONLY, build a mapping of original UNNEST aliases
309        # from alias.columns[0] to their source names. This is used to resolve shadowing
310        # where an UNNEST alias shadows a column name from another table.
311        unnest_original_aliases: dict[str, str] = {}
312        if self.dialect.UNNEST_COLUMN_ONLY:
313            unnest_original_aliases = {
314                alias_arg.columns[0].name: source_name
315                for source_name, source in self.scope.sources.items()
316                if (
317                    isinstance(source.expression, exp.Unnest)
318                    and (alias_arg := source.expression.args.get("alias"))
319                    and alias_arg.columns
320                )
321            }
322
323        unambiguous_columns = {col: first_table for col in first_columns}
324        all_columns = set(unambiguous_columns)
325
326        for table, columns in source_columns_pairs[1:]:
327            unique = set(columns)
328            ambiguous = all_columns.intersection(unique)
329            all_columns.update(columns)
330
331            for column in ambiguous:
332                if column in unnest_original_aliases:
333                    unambiguous_columns[column] = unnest_original_aliases[column]
334                    continue
335
336                unambiguous_columns.pop(column, None)
337            for column in unique.difference(ambiguous):
338                unambiguous_columns[column] = table
339
340        return unambiguous_columns
341
342    def _get_unnest_column_type(self, column: exp.Column) -> exp.DataType | None:
343        """
344        Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.
345
346        Args:
347            column: The column expression being unnested.
348
349        Returns:
350            The DataType of the column, or None if not found.
351        """
352        scope = self.scope.parent
353        assert scope
354
355        # if column is qualified, use that table, otherwise disambiguate using the resolver
356        if column.table:
357            table_name = column.table
358        else:
359            # use the parent scope's resolver to disambiguate the column
360            parent_resolver = Resolver(scope, self.schema, self._infer_schema)
361            table_identifier = parent_resolver.get_table(column)
362            if not table_identifier:
363                return None
364            table_name = table_identifier.name
365
366        source = scope.sources.get(table_name)
367        return self._get_column_type_from_scope(source, column) if source else None
368
369    def _get_column_type_from_scope(
370        self, source: Scope | exp.Table, column: exp.Column
371    ) -> exp.DataType | None:
372        """
373        Get a column's type by tracing through scopes/tables to find the base table.
374
375        Args:
376            source: The source to search - can be a Scope (to iterate its sources) or a Table.
377            column: The column to find the type for.
378
379        Returns:
380            The DataType of the column, or None if not found.
381        """
382        if isinstance(source, exp.Table):
383            # base table - get the column type from schema
384            col_type: exp.DataType | None = self.schema.get_column_type(source, column)
385            if col_type and not col_type.is_type(exp.DType.UNKNOWN):
386                return col_type
387        elif isinstance(source, Scope):
388            # iterate over all sources in the scope
389            for source_name, nested_source in source.sources.items():
390                col_type = self._get_column_type_from_scope(nested_source, column)
391                if col_type and not col_type.is_type(exp.DType.UNKNOWN):
392                    return col_type
393
394        return None

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: The scope whose columns are being resolved.
            schema: Schema used to look up source column names.
            infer_schema: If True, allow falling back to schema inference when a
                column's table cannot be determined from the known sources.
        """
        self.scope = scope
        self.schema = schema
        # Fall back to a default dialect when the schema does not carry one.
        self.dialect = schema.dialect or Dialect()
        # Lazily-populated caches, shared across resolution calls on this instance.
        self._source_columns: dict[str, Sequence[str]] | None = None
        self._unambiguous_columns: Mapping[str, str] | None = None
        self._all_columns: set[str] | None = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: dict[tuple[str, bool], Sequence[str]] = {}
scope
schema
dialect
def get_table( self, column: str | sqlglot.expressions.core.Column) -> sqlglot.expressions.core.Identifier | None:
    def get_table(self, column: str | exp.Column) -> exp.Identifier | None:
        """
        Get the table for a column name.

        Args:
            column: The column expression (or column name) to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        column_name = column if isinstance(column, str) else column.name

        table_name = self._get_table_name_from_sources(column_name)

        if not table_name and isinstance(column, exp.Column):
            # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
            # attempt to disambiguate the column based on other characteristics, e.g. if this
            # column is in a join condition, we may be able to disambiguate based on the source order.
            if join_context := self._get_column_join_context(column):
                # In this case, the return value will be the join that _may_ be able to disambiguate the column
                # and we can use the source columns available at that join to get the table name.
                # Catch OptimizeError if the column is still ambiguous and try to resolve
                # with schema inference below.
                try:
                    table_name = self._get_table_name_from_sources(
                        column_name, self._get_available_source_columns(join_context)
                    )
                except OptimizeError:
                    pass

        if not table_name and self._infer_schema:
            # Last resort: if exactly one source has an unknown column list
            # (empty or containing "*"), assume the column belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            # Not a selected source (possibly None); return the name as-is.
            return exp.to_identifier(table_name)

        node: exp.Expr = self.scope.selected_sources[table_name][0]

        if isinstance(node, exp.Query):
            # Walk up from the query node to the ancestor that actually carries
            # the alias matching `table_name`.
            while node and node.alias != table_name and node.parent:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            # Prefer the source's alias identifier so quoting/casing is preserved.
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column: The column expression (or column name) to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: set[str]
87    @property
88    def all_columns(self) -> set[str]:
89        """All available columns of all sources in this scope"""
90        if self._all_columns is None:
91            self._all_columns = {
92                column for columns in self._get_all_source_columns().values() for column in columns
93            }
94        return self._all_columns

All available columns of all sources in this scope

def get_source_columns_from_set_op(self, expression: sqlglot.expressions.core.Expr) -> list[str]:
 96    def get_source_columns_from_set_op(self, expression: exp.Expr) -> list[str]:
 97        if isinstance(expression, exp.Select):
 98            return expression.named_selects
 99        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
100            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
101            return self.get_source_columns_from_set_op(expression.this)
102        if not isinstance(expression, exp.SetOperation):
103            raise OptimizeError(f"Unknown set operation: {expression}")
104
105        set_op = expression
106
107        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
108        on_column_list = set_op.args.get("on")
109
110        if on_column_list:
111            # The resulting columns are the columns in the ON clause:
112            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
113            columns = [col.name for col in on_column_list]
114        elif set_op.side or set_op.kind:
115            side = set_op.side
116            kind = set_op.kind
117
118            # Visit the children UNIONs (if any) in a post-order traversal
119            left = self.get_source_columns_from_set_op(set_op.left)
120            right = self.get_source_columns_from_set_op(set_op.right)
121
122            # We use dict.fromkeys to deduplicate keys and maintain insertion order
123            if side == "LEFT":
124                columns = left
125            elif side == "FULL":
126                columns = list(dict.fromkeys(left + right))
127            elif kind == "INNER":
128                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
129        else:
130            columns = set_op.named_selects
131
132        return columns
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
    def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
        """Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source registered in this scope.
            only_visible: For base tables, restrict the result to schema-visible columns.

        Returns:
            The column names exposed by the source, with column aliases applied.

        Raises:
            OptimizeError: If `name` is not a known source in this scope.
        """
        # Results are memoized per (name, only_visible) since resolution can be expensive.
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Base table: the schema is the authority on its columns.
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
                    unnest = source.expression

                    # if type is not annotated yet, try to get it from the schema
                    if not unnest.type or unnest.type.is_type(exp.DType.UNKNOWN):
                        unnest_expr = seq_get(unnest.expressions, 0)
                        if isinstance(unnest_expr, exp.Column) and self.scope.parent:
                            col_type = self._get_unnest_column_type(unnest_expr)
                            # extract element type if it's an ARRAY
                            if col_type and col_type.is_type(exp.DType.ARRAY):
                                element_types = col_type.expressions
                                if element_types:
                                    unnest.type = element_types[0].copy()
                            else:
                                if col_type:
                                    unnest.type = col_type.copy()
                    # check if the result type is a STRUCT - extract struct field names
                    if unnest.is_type(exp.DType.STRUCT):
                        # NOTE(review): this appends to the list returned by
                        # named_selects above - assumes that list is a fresh copy,
                        # not shared state on the expression; confirm.
                        for k in unnest.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                columns = self.get_source_columns_from_set_op(source.expression)
            else:
                # Generic selectable (subquery, CTE, ...): use its first select.
                selectable = source.expression.assert_is(exp.Selectable)
                select = seq_get(selectable.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = selectable.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expr):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                # zip_longest keeps unaliased trailing columns under their original names.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

Resolve the source columns for a given source name.