
sqlglot.optimizer.resolver

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot.dialects.dialect import Dialect
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.scope import Scope
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.schema import Schema
 14
 15
 16class Resolver:
 17    """
 18    Helper for resolving columns.
 19
 20    This is a class so we can lazily load some things and easily share them across functions.
 21    """
 22
 23    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
 24        self.scope = scope
 25        self.schema = schema
 26        self.dialect = schema.dialect or Dialect()
 27        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
 28        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
 29        self._all_columns: t.Optional[t.Set[str]] = None
 30        self._infer_schema = infer_schema
 31        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
 32
 33    def get_table(self, column: str | exp.Column) -> t.Optional[exp.Identifier]:
 34        """
 35        Get the table for a column name.
 36
 37        Args:
 38            column: The column expression (or column name) to find the table for.
 39        Returns:
 40            The table name if it can be found/inferred.
 41        """
 42        column_name = column if isinstance(column, str) else column.name
 43
 44        table_name = self._get_table_name_from_sources(column_name)
 45
 46        if not table_name and isinstance(column, exp.Column):
 47            # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
 48            # attempt to disambiguate the column based on other characteristics e.g if this column is in a join condition,
 49            # we may be able to disambiguate based on the source order.
 50            if join_context := self._get_column_join_context(column):
 51                # In this case, the return value will be the join that _may_ be able to disambiguate the column
 52                # and we can use the source columns available at that join to get the table name
 53                # catch OptimizeError if column is still ambiguous and try to resolve with schema inference below
 54                try:
 55                    table_name = self._get_table_name_from_sources(
 56                        column_name, self._get_available_source_columns(join_context)
 57                    )
 58                except OptimizeError:
 59                    pass
 60
 61        if not table_name and self._infer_schema:
 62            sources_without_schema = tuple(
 63                source
 64                for source, columns in self._get_all_source_columns().items()
 65                if not columns or "*" in columns
 66            )
 67            if len(sources_without_schema) == 1:
 68                table_name = sources_without_schema[0]
 69
 70        if table_name not in self.scope.selected_sources:
 71            return exp.to_identifier(table_name)
 72
 73        node, _ = self.scope.selected_sources.get(table_name)
 74
 75        if isinstance(node, exp.Query):
 76            while node and node.alias != table_name:
 77                node = node.parent
 78
 79        node_alias = node.args.get("alias")
 80        if node_alias:
 81            return exp.to_identifier(node_alias.this)
 82
 83        return exp.to_identifier(table_name)
 84
 85    @property
 86    def all_columns(self) -> t.Set[str]:
 87        """All available columns of all sources in this scope"""
 88        if self._all_columns is None:
 89            self._all_columns = {
 90                column for columns in self._get_all_source_columns().values() for column in columns
 91            }
 92        return self._all_columns
 93
 94    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
 95        if isinstance(expression, exp.Select):
 96            return expression.named_selects
 97        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
 98            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
 99            return self.get_source_columns_from_set_op(expression.this)
100        if not isinstance(expression, exp.SetOperation):
101            raise OptimizeError(f"Unknown set operation: {expression}")
102
103        set_op = expression
104
105        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
106        on_column_list = set_op.args.get("on")
107
108        if on_column_list:
109            # The resulting columns are the columns in the ON clause:
110            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
111            columns = [col.name for col in on_column_list]
112        elif set_op.side or set_op.kind:
113            side = set_op.side
114            kind = set_op.kind
115
116            # Visit the children UNIONs (if any) in a post-order traversal
117            left = self.get_source_columns_from_set_op(set_op.left)
118            right = self.get_source_columns_from_set_op(set_op.right)
119
120            # We use dict.fromkeys to deduplicate keys and maintain insertion order
121            if side == "LEFT":
122                columns = left
123            elif side == "FULL":
124                columns = list(dict.fromkeys(left + right))
125            elif kind == "INNER":
126                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
127        else:
128            columns = set_op.named_selects
129
130        return columns
131
132    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
133        """Resolve the source columns for a given source `name`."""
134        cache_key = (name, only_visible)
135        if cache_key not in self._get_source_columns_cache:
136            if name not in self.scope.sources:
137                raise OptimizeError(f"Unknown table: {name}")
138
139            source = self.scope.sources[name]
140
141            if isinstance(source, exp.Table):
142                columns = self.schema.column_names(source, only_visible)
143            elif isinstance(source, Scope) and isinstance(
144                source.expression, (exp.Values, exp.Unnest)
145            ):
146                columns = source.expression.named_selects
147
148                # in bigquery, unnest structs are automatically scoped as tables, so you can
149                # directly select a struct field in a query.
150                # this handles the case where the unnest is statically defined.
151                if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
152                    unnest = source.expression
153
154                    # if type is not annotated yet, try to get it from the schema
155                    if not unnest.type or unnest.type.is_type(exp.DataType.Type.UNKNOWN):
156                        unnest_expr = seq_get(unnest.expressions, 0)
157                        if isinstance(unnest_expr, exp.Column) and self.scope.parent:
158                            col_type = self._get_unnest_column_type(unnest_expr)
159                            # extract element type if it's an ARRAY
160                            if col_type and col_type.is_type(exp.DataType.Type.ARRAY):
161                                element_types = col_type.expressions
162                                if element_types:
163                                    unnest.type = element_types[0].copy()
164                            else:
165                                if col_type:
166                                    unnest.type = col_type.copy()
167                    # check if the result type is a STRUCT - extract struct field names
168                    if unnest.is_type(exp.DataType.Type.STRUCT):
169                        for k in unnest.type.expressions:  # type: ignore
170                            columns.append(k.name)
171            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
172                columns = self.get_source_columns_from_set_op(source.expression)
173
174            else:
175                select = seq_get(source.expression.selects, 0)
176
177                if isinstance(select, exp.QueryTransform):
178                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
179                    schema = select.args.get("schema")
180                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
181                else:
182                    columns = source.expression.named_selects
183
184            node, _ = self.scope.selected_sources.get(name) or (None, None)
185            if isinstance(node, Scope):
186                column_aliases = node.expression.alias_column_names
187            elif isinstance(node, exp.Expression):
188                column_aliases = node.alias_column_names
189            else:
190                column_aliases = []
191
192            if column_aliases:
193                # If the source's columns are aliased, their aliases shadow the corresponding column names.
194                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
195                columns = [
196                    alias or name
197                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
198                ]
199
200            self._get_source_columns_cache[cache_key] = columns
201
202        return self._get_source_columns_cache[cache_key]
203
204    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
205        if self._source_columns is None:
206            self._source_columns = {
207                source_name: self.get_source_columns(source_name)
208                for source_name, source in itertools.chain(
209                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
210                )
211            }
212        return self._source_columns
213
214    def _get_table_name_from_sources(
215        self, column_name: str, source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
216    ) -> t.Optional[str]:
217        if not source_columns:
218            # If not supplied, get all sources to calculate unambiguous columns
219            if self._unambiguous_columns is None:
220                self._unambiguous_columns = self._get_unambiguous_columns(
221                    self._get_all_source_columns()
222                )
223
224            unambiguous_columns = self._unambiguous_columns
225        else:
226            unambiguous_columns = self._get_unambiguous_columns(source_columns)
227
228        return unambiguous_columns.get(column_name)
229
230    def _get_column_join_context(self, column: exp.Column) -> t.Optional[exp.Join]:
231        """
232        Check if a column participating in a join can be qualified based on the source order.
233        """
234        args = self.scope.expression.args
235        joins = args.get("joins")
236
237        if not joins or args.get("laterals") or args.get("pivots"):
238            # Feature gap: We currently don't try to disambiguate columns if other sources
239            # (e.g laterals, pivots) exist alongside joins
240            return None
241
242        join_ancestor = column.find_ancestor(exp.Join, exp.Select)
243
244        if (
245            isinstance(join_ancestor, exp.Join)
246            and join_ancestor.alias_or_name in self.scope.selected_sources
247        ):
248            # Ensure that the found ancestor is a join that contains an actual source,
249            # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
250            return join_ancestor
251
252        return None
253
254    def _get_available_source_columns(
255        self, join_ancestor: exp.Join
256    ) -> t.Dict[str, t.Sequence[str]]:
257        """
258        Get the source columns that are available at the point where a column is referenced.
259
260        For columns in JOIN conditions, this only includes tables that have been joined
261        up to that point. Example:
262
263        ```
264        SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ...
265        ```                                                        ^
266                                                                   |
267                                +----------------------------------+
268                                |
269                                ⌄
270        The unqualified column `c` is not ambiguous if no other sources up until that
271        join i.e t_1, ..., t_n, contain a column named `c`.
272
273        """
274        args = self.scope.expression.args
275
276        # Collect tables in order: FROM clause tables + joined tables up to current join
277        from_name = args["from_"].alias_or_name
278        available_sources = {from_name: self.get_source_columns(from_name)}
279
280        for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
281            available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)
282
283        return available_sources
284
285    def _get_unambiguous_columns(
286        self, source_columns: t.Dict[str, t.Sequence[str]]
287    ) -> t.Mapping[str, str]:
288        """
289        Find all the unambiguous columns in sources.
290
291        Args:
292            source_columns: Mapping of names to source columns.
293
294        Returns:
295            Mapping of column name to source name.
296        """
297        if not source_columns:
298            return {}
299
300        source_columns_pairs = list(source_columns.items())
301
302        first_table, first_columns = source_columns_pairs[0]
303
304        if len(source_columns_pairs) == 1:
305            # Performance optimization - avoid copying first_columns if there is only one table.
306            return SingleValuedMapping(first_columns, first_table)
307
308        # For BigQuery UNNEST_COLUMN_ONLY, build a mapping of original UNNEST aliases
309        # from alias.columns[0] to their source names. This is used to resolve shadowing
310        # where an UNNEST alias shadows a column name from another table.
311        unnest_original_aliases: t.Dict[str, str] = {}
312        if self.dialect.UNNEST_COLUMN_ONLY:
313            unnest_original_aliases = {
314                alias_arg.columns[0].name: source_name
315                for source_name, source in self.scope.sources.items()
316                if (
317                    isinstance(source.expression, exp.Unnest)
318                    and (alias_arg := source.expression.args.get("alias"))
319                    and alias_arg.columns
320                )
321            }
322
323        unambiguous_columns = {col: first_table for col in first_columns}
324        all_columns = set(unambiguous_columns)
325
326        for table, columns in source_columns_pairs[1:]:
327            unique = set(columns)
328            ambiguous = all_columns.intersection(unique)
329            all_columns.update(columns)
330
331            for column in ambiguous:
332                if column in unnest_original_aliases:
333                    unambiguous_columns[column] = unnest_original_aliases[column]
334                    continue
335
336                unambiguous_columns.pop(column, None)
337            for column in unique.difference(ambiguous):
338                unambiguous_columns[column] = table
339
340        return unambiguous_columns
341
342    def _get_unnest_column_type(self, column: exp.Column) -> t.Optional[exp.DataType]:
343        """
344        Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.
345
346        Args:
347            column: The column expression being unnested.
348
349        Returns:
350            The DataType of the column, or None if not found.
351        """
352        scope = self.scope.parent
353
354        # if column is qualified, use that table, otherwise disambiguate using the resolver
355        if column.table:
356            table_name = column.table
357        else:
358            # use the parent scope's resolver to disambiguate the column
359            parent_resolver = Resolver(scope, self.schema, self._infer_schema)
360            table_identifier = parent_resolver.get_table(column)
361            if not table_identifier:
362                return None
363            table_name = table_identifier.name
364
365        source = scope.sources.get(table_name)
366        return self._get_column_type_from_scope(source, column) if source else None
367
368    def _get_column_type_from_scope(
369        self, source: t.Union[Scope, exp.Table], column: exp.Column
370    ) -> t.Optional[exp.DataType]:
371        """
372        Get a column's type by tracing through scopes/tables to find the base table.
373
374        Args:
375            source: The source to search - can be a Scope (to iterate its sources) or a Table.
376            column: The column to find the type for.
377
378        Returns:
379            The DataType of the column, or None if not found.
380        """
381        if isinstance(source, exp.Table):
382            # base table - get the column type from schema
383            col_type: t.Optional[exp.DataType] = self.schema.get_column_type(source, column)
384            if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
385                return col_type
386        elif isinstance(source, Scope):
387            # iterate over all sources in the scope
388            for source_name, nested_source in source.sources.items():
389                col_type = self._get_column_type_from_scope(nested_source, column)
390                if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
391                    return col_type
392
393        return None
class Resolver:

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.
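A minimal construction sketch (not part of this module: the toy schema, the SQL, and the use of the `build_scope`/`ensure_schema` helpers are assumptions for illustration). A `Resolver` is built from a query's root `Scope` and a `Schema`, which is roughly the wiring the column-qualification pass does internally; the per-method examples below reuse this setup.

```
import sqlglot
from sqlglot.optimizer.resolver import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

# Toy schema: `b` exists in both tables, `a` and `c` are unique to one table each.
schema = ensure_schema({"t1": {"a": "INT", "b": "INT"}, "t2": {"b": "INT", "c": "INT"}})

expression = sqlglot.parse_one("SELECT a, b, c FROM t1 CROSS JOIN t2")
scope = build_scope(expression)  # root Scope of the SELECT

resolver = Resolver(scope, schema)  # see the per-method examples below
```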

Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
scope
schema
dialect
def get_table(self, column: str | sqlglot.expressions.Column) -> Optional[sqlglot.expressions.Identifier]:

Get the table for a column name.

Arguments:
  • column: The column expression (or column name) to find the table for.

Returns:
  The table name if it can be found/inferred.
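A hedged example against the toy schema from the constructor sketch (all names assumed): unambiguous columns resolve to their table, ambiguous ones come back as `None`, and with `infer_schema=True` (the default) an otherwise unknown column is attributed to the single source that has no schema information.

```
import sqlglot
from sqlglot.optimizer.resolver import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t1": {"a": "INT", "b": "INT"}, "t2": {"b": "INT", "c": "INT"}})
resolver = Resolver(build_scope(sqlglot.parse_one("SELECT a, b, c FROM t1 CROSS JOIN t2")), schema)

print(resolver.get_table("a"))  # t1   -- only t1 has a column `a`
print(resolver.get_table("c"))  # t2   -- only t2 has a column `c`
print(resolver.get_table("b"))  # None -- both sources have `b`, so it stays ambiguous

# Schema inference fall-back: `unknown_table` is not in the schema, so it is the only
# source without known columns and `x` is assumed to belong to it.
resolver = Resolver(build_scope(sqlglot.parse_one("SELECT x FROM t1 CROSS JOIN unknown_table")), schema)
print(resolver.get_table("x"))  # unknown_table
```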

all_columns: Set[str]

All available columns of all sources in this scope
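For example (toy schema assumed, as in the sketches above), this is just the union of every source's columns:

```
import sqlglot
from sqlglot.optimizer.resolver import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t1": {"a": "INT", "b": "INT"}, "t2": {"b": "INT", "c": "INT"}})
resolver = Resolver(build_scope(sqlglot.parse_one("SELECT * FROM t1 CROSS JOIN t2")), schema)

print(sorted(resolver.all_columns))  # ['a', 'b', 'c']
```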

def get_source_columns_from_set_op(self, expression: sqlglot.expressions.Expression) -> List[str]:
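Computes the output column names of a set operation (e.g. a `UNION` inside a derived table), including the BigQuery-style `BY NAME`/side/kind variants handled in the source above. A minimal sketch of the plain-`UNION` path (toy schema assumed), where the result is the left side's named selects; it is normally invoked indirectly through `get_source_columns`:

```
import sqlglot
from sqlglot.optimizer.resolver import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t1": {"a": "INT", "b": "INT"}, "t2": {"b": "INT", "c": "INT"}})
sql = "SELECT * FROM (SELECT a, b FROM t1 UNION ALL SELECT b, c FROM t2) AS u"
resolver = Resolver(build_scope(sqlglot.parse_one(sql)), schema)

# The derived table `u` is a set operation, so its columns come from this method.
print(resolver.get_source_columns("u"))  # ['a', 'b']
```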
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:

Resolve the source columns for a given source name.
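A hedged example (toy schema and names assumed, as above): table sources are resolved through the schema, while subquery sources report their named selects.

```
import sqlglot
from sqlglot.optimizer.resolver import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t1": {"a": "INT", "b": "INT"}, "t2": {"b": "INT", "c": "INT"}})
sql = "SELECT * FROM t1 JOIN (SELECT b, c AS total FROM t2) AS sub ON t1.b = sub.b"
resolver = Resolver(build_scope(sqlglot.parse_one(sql)), schema)

print(resolver.get_source_columns("t1"))   # ['a', 'b']     -- looked up in the schema
print(resolver.get_source_columns("sub"))  # ['b', 'total'] -- the subquery's named selects
```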