sqlglot.optimizer.resolver
"""Column resolution helper for the sqlglot optimizer.

Reconstructed formatting; logic is unchanged from the extracted source.
"""

from __future__ import annotations

import itertools
import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.scope import Scope

if t.TYPE_CHECKING:
    from sqlglot.schema import Schema
    from collections.abc import Sequence, Mapping


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Fall back to a default Dialect when the schema carries none.
        self.dialect = schema.dialect or Dialect()
        # Lazily-populated caches; computed on first use and reused thereafter.
        self._source_columns: dict[str, Sequence[str]] | None = None
        self._unambiguous_columns: Mapping[str, str] | None = None
        self._all_columns: set[str] | None = None
        self._infer_schema = infer_schema
        # Memoizes get_source_columns, keyed by (source name, only_visible).
        self._get_source_columns_cache: dict[tuple[str, bool], Sequence[str]] = {}

    def get_table(self, column: str | exp.Column) -> exp.Identifier | None:
        """
        Get the table for a column name.

        Args:
            column: The column expression (or column name) to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        column_name = column if isinstance(column, str) else column.name

        table_name = self._get_table_name_from_sources(column_name)

        if not table_name and isinstance(column, exp.Column):
            # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
            # attempt to disambiguate the column based on other characteristics e.g if this
            # column is in a join condition, we may be able to disambiguate based on the
            # source order.
            if join_context := self._get_column_join_context(column):
                # In this case, the return value will be the join that _may_ be able to
                # disambiguate the column and we can use the source columns available at
                # that join to get the table name.
                # Catch OptimizeError if column is still ambiguous and try to resolve with
                # schema inference below.
                try:
                    table_name = self._get_table_name_from_sources(
                        column_name, self._get_available_source_columns(join_context)
                    )
                except OptimizeError:
                    pass

        if not table_name and self._infer_schema:
            # If exactly one source has no (or wildcard) columns, assume the column
            # belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node: exp.Expr = self.scope.selected_sources[table_name][0]

        if isinstance(node, exp.Query):
            # Walk up to the node that actually carries the alias for this source.
            while node and node.alias != table_name and node.parent:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns_from_set_op(self, expression: exp.Expr) -> list[str]:
        """Compute the output column names of a set operation (UNION/INTERSECT/...).

        Recurses into nested subqueries/set operations; raises OptimizeError for
        unsupported expression types.
        """
        if isinstance(expression, exp.Select):
            return expression.named_selects
        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
            # Different types of SET modifiers can be chained together if they're explicitly
            # grouped by nesting
            return self.get_source_columns_from_set_op(expression.this)
        if not isinstance(expression, exp.SetOperation):
            raise OptimizeError(f"Unknown set operation: {expression}")

        set_op = expression

        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
        on_column_list = set_op.args.get("on")

        if on_column_list:
            # The resulting columns are the columns in the ON clause:
            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
            columns = [col.name for col in on_column_list]
        elif set_op.side or set_op.kind:
            side = set_op.side
            kind = set_op.kind

            # Visit the children UNIONs (if any) in a post-order traversal
            left = self.get_source_columns_from_set_op(set_op.left)
            right = self.get_source_columns_from_set_op(set_op.right)

            # We use dict.fromkeys to deduplicate keys and maintain insertion order
            # NOTE(review): the INNER branch intersects two key views, which yields a
            # set — insertion order is not actually preserved there; confirm intended.
            if side == "LEFT":
                columns = left
            elif side == "FULL":
                columns = list(dict.fromkeys(left + right))
            elif kind == "INNER":
                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
        else:
            columns = set_op.named_selects

        return columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
                    unnest = source.expression

                    # if type is not annotated yet, try to get it from the schema
                    if not unnest.type or unnest.type.is_type(exp.DType.UNKNOWN):
                        unnest_expr = seq_get(unnest.expressions, 0)
                        if isinstance(unnest_expr, exp.Column) and self.scope.parent:
                            col_type = self._get_unnest_column_type(unnest_expr)
                            # extract element type if it's an ARRAY
                            if col_type and col_type.is_type(exp.DType.ARRAY):
                                element_types = col_type.expressions
                                if element_types:
                                    unnest.type = element_types[0].copy()
                            else:
                                if col_type:
                                    unnest.type = col_type.copy()

                    # check if the result type is a STRUCT - extract struct field names
                    if unnest.is_type(exp.DType.STRUCT):
                        for k in unnest.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                columns = self.get_source_columns_from_set_op(source.expression)
            else:
                selectable = source.expression.assert_is(exp.Selectable)
                select = seq_get(selectable.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = selectable.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expr):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding
                # column names. This can be expensive if there are lots of columns, so only
                # do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> dict[str, Sequence[str]]:
        """Map every selected/lateral source name to its resolved columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_table_name_from_sources(
        self, column_name: str, source_columns: dict[str, Sequence[str]] | None = None
    ) -> str | None:
        """Return the unique source name owning `column_name`, or None if ambiguous/unknown."""
        if not source_columns:
            # If not supplied, get all sources to calculate unambiguous columns
            if self._unambiguous_columns is None:
                self._unambiguous_columns = self._get_unambiguous_columns(
                    self._get_all_source_columns()
                )

            unambiguous_columns = self._unambiguous_columns
        else:
            unambiguous_columns = self._get_unambiguous_columns(source_columns)

        return unambiguous_columns.get(column_name)

    def _get_column_join_context(self, column: exp.Column) -> exp.Join | None:
        """
        Check if a column participating in a join can be qualified based on the source order.
        """
        args = self.scope.expression.args
        joins = args.get("joins")

        if not joins or args.get("laterals") or args.get("pivots"):
            # Feature gap: We currently don't try to disambiguate columns if other sources
            # (e.g laterals, pivots) exist alongside joins
            return None

        join_ancestor = column.find_ancestor(exp.Join, exp.Select)

        if (
            isinstance(join_ancestor, exp.Join)
            and join_ancestor.alias_or_name in self.scope.selected_sources
        ):
            # Ensure that the found ancestor is a join that contains an actual source,
            # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
            return join_ancestor

        return None

    def _get_available_source_columns(self, join_ancestor: exp.Join) -> dict[str, Sequence[str]]:
        """
        Get the source columns that are available at the point where a column is referenced.

        For columns in JOIN conditions, this only includes tables that have been joined
        up to that point. Example:

        ```
        SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ...
        ```                                                        ^
                                                                   |
                                        +--------------------------+
                                        |
                                        v
        The unqualified column `c` is not ambiguous if no other sources up until that
        join i.e t_1, ..., t_n, contain a column named `c`.
        """
        args = self.scope.expression.args

        # Collect tables in order: FROM clause tables + joined tables up to current join
        from_name = args["from_"].alias_or_name
        available_sources = {from_name: self.get_source_columns(from_name)}

        for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
            available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)

        return available_sources

    def _get_unambiguous_columns(
        self, source_columns: dict[str, Sequence[str]]
    ) -> Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        # For BigQuery UNNEST_COLUMN_ONLY, build a mapping of original UNNEST aliases
        # from alias.columns[0] to their source names. This is used to resolve shadowing
        # where an UNNEST alias shadows a column name from another table.
        unnest_original_aliases: dict[str, str] = {}
        if self.dialect.UNNEST_COLUMN_ONLY:
            unnest_original_aliases = {
                alias_arg.columns[0].name: source_name
                for source_name, source in self.scope.sources.items()
                if (
                    isinstance(source.expression, exp.Unnest)
                    and (alias_arg := source.expression.args.get("alias"))
                    and alias_arg.columns
                )
            }

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                if column in unnest_original_aliases:
                    # UNNEST alias shadowing wins over ambiguity removal.
                    unambiguous_columns[column] = unnest_original_aliases[column]
                    continue

                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    def _get_unnest_column_type(self, column: exp.Column) -> exp.DataType | None:
        """
        Get the type of a column being unnested, tracing through CTEs/subqueries to find
        the base table.

        Args:
            column: The column expression being unnested.

        Returns:
            The DataType of the column, or None if not found.
        """
        scope = self.scope.parent
        assert scope

        # if column is qualified, use that table, otherwise disambiguate using the resolver
        if column.table:
            table_name = column.table
        else:
            # use the parent scope's resolver to disambiguate the column
            parent_resolver = Resolver(scope, self.schema, self._infer_schema)
            table_identifier = parent_resolver.get_table(column)
            if not table_identifier:
                return None
            table_name = table_identifier.name

        source = scope.sources.get(table_name)
        return self._get_column_type_from_scope(source, column) if source else None

    def _get_column_type_from_scope(
        self, source: Scope | exp.Table, column: exp.Column
    ) -> exp.DataType | None:
        """
        Get a column's type by tracing through scopes/tables to find the base table.

        Args:
            source: The source to search - can be a Scope (to iterate its sources) or a Table.
            column: The column to find the type for.

        Returns:
            The DataType of the column, or None if not found.
        """
        if isinstance(source, exp.Table):
            # base table - get the column type from schema
            col_type: exp.DataType | None = self.schema.get_column_type(source, column)
            if col_type and not col_type.is_type(exp.DType.UNKNOWN):
                return col_type
        elif isinstance(source, Scope):
            # iterate over all sources in the scope
            for source_name, nested_source in source.sources.items():
                col_type = self._get_column_type_from_scope(nested_source, column)
                if col_type and not col_type.is_type(exp.DType.UNKNOWN):
                    return col_type

        return None
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Fall back to a default Dialect when the schema carries none.
        self.dialect = schema.dialect or Dialect()
        # Lazily-populated caches; computed on first use and reused thereafter.
        self._source_columns: dict[str, Sequence[str]] | None = None
        self._unambiguous_columns: Mapping[str, str] | None = None
        self._all_columns: set[str] | None = None
        self._infer_schema = infer_schema
        # Memoizes get_source_columns, keyed by (source name, only_visible).
        self._get_source_columns_cache: dict[tuple[str, bool], Sequence[str]] = {}

    def get_table(self, column: str | exp.Column) -> exp.Identifier | None:
        """
        Get the table for a column name.

        Args:
            column: The column expression (or column name) to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        column_name = column if isinstance(column, str) else column.name

        table_name = self._get_table_name_from_sources(column_name)

        if not table_name and isinstance(column, exp.Column):
            # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
            # attempt to disambiguate the column based on other characteristics e.g if this
            # column is in a join condition, we may be able to disambiguate based on the
            # source order.
            if join_context := self._get_column_join_context(column):
                # The returned join _may_ be able to disambiguate the column; use the
                # source columns available at that join to get the table name.
                # Catch OptimizeError if column is still ambiguous and try to resolve with
                # schema inference below.
                try:
                    table_name = self._get_table_name_from_sources(
                        column_name, self._get_available_source_columns(join_context)
                    )
                except OptimizeError:
                    pass

        if not table_name and self._infer_schema:
            # If exactly one source has no (or wildcard) columns, assume the column
            # belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node: exp.Expr = self.scope.selected_sources[table_name][0]

        if isinstance(node, exp.Query):
            # Walk up to the node that actually carries the alias for this source.
            while node and node.alias != table_name and node.parent:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns_from_set_op(self, expression: exp.Expr) -> list[str]:
        """Compute the output column names of a set operation (UNION/INTERSECT/...).

        Recurses into nested subqueries/set operations; raises OptimizeError for
        unsupported expression types.
        """
        if isinstance(expression, exp.Select):
            return expression.named_selects
        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
            # Different types of SET modifiers can be chained together if they're explicitly
            # grouped by nesting
            return self.get_source_columns_from_set_op(expression.this)
        if not isinstance(expression, exp.SetOperation):
            raise OptimizeError(f"Unknown set operation: {expression}")

        set_op = expression

        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
        on_column_list = set_op.args.get("on")

        if on_column_list:
            # The resulting columns are the columns in the ON clause:
            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
            columns = [col.name for col in on_column_list]
        elif set_op.side or set_op.kind:
            side = set_op.side
            kind = set_op.kind

            # Visit the children UNIONs (if any) in a post-order traversal
            left = self.get_source_columns_from_set_op(set_op.left)
            right = self.get_source_columns_from_set_op(set_op.right)

            # We use dict.fromkeys to deduplicate keys and maintain insertion order
            if side == "LEFT":
                columns = left
            elif side == "FULL":
                columns = list(dict.fromkeys(left + right))
            elif kind == "INNER":
                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
        else:
            columns = set_op.named_selects

        return columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
                    unnest = source.expression

                    # if type is not annotated yet, try to get it from the schema
                    if not unnest.type or unnest.type.is_type(exp.DType.UNKNOWN):
                        unnest_expr = seq_get(unnest.expressions, 0)
                        if isinstance(unnest_expr, exp.Column) and self.scope.parent:
                            col_type = self._get_unnest_column_type(unnest_expr)
                            # extract element type if it's an ARRAY
                            if col_type and col_type.is_type(exp.DType.ARRAY):
                                element_types = col_type.expressions
                                if element_types:
                                    unnest.type = element_types[0].copy()
                            else:
                                if col_type:
                                    unnest.type = col_type.copy()

                    # check if the result type is a STRUCT - extract struct field names
                    if unnest.is_type(exp.DType.STRUCT):
                        for k in unnest.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                columns = self.get_source_columns_from_set_op(source.expression)
            else:
                selectable = source.expression.assert_is(exp.Selectable)
                select = seq_get(selectable.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = selectable.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expr):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding
                # column names. This can be expensive if there are lots of columns, so only
                # do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> dict[str, Sequence[str]]:
        """Map every selected/lateral source name to its resolved columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_table_name_from_sources(
        self, column_name: str, source_columns: dict[str, Sequence[str]] | None = None
    ) -> str | None:
        """Return the unique source name owning `column_name`, or None if ambiguous/unknown."""
        if not source_columns:
            # If not supplied, get all sources to calculate unambiguous columns
            if self._unambiguous_columns is None:
                self._unambiguous_columns = self._get_unambiguous_columns(
                    self._get_all_source_columns()
                )

            unambiguous_columns = self._unambiguous_columns
        else:
            unambiguous_columns = self._get_unambiguous_columns(source_columns)

        return unambiguous_columns.get(column_name)

    def _get_column_join_context(self, column: exp.Column) -> exp.Join | None:
        """
        Check if a column participating in a join can be qualified based on the source order.
        """
        args = self.scope.expression.args
        joins = args.get("joins")

        if not joins or args.get("laterals") or args.get("pivots"):
            # Feature gap: We currently don't try to disambiguate columns if other sources
            # (e.g laterals, pivots) exist alongside joins
            return None

        join_ancestor = column.find_ancestor(exp.Join, exp.Select)

        if (
            isinstance(join_ancestor, exp.Join)
            and join_ancestor.alias_or_name in self.scope.selected_sources
        ):
            # Ensure that the found ancestor is a join that contains an actual source,
            # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
            return join_ancestor

        return None

    def _get_available_source_columns(self, join_ancestor: exp.Join) -> dict[str, Sequence[str]]:
        """
        Get the source columns that are available at the point where a column is referenced.

        For columns in JOIN conditions, this only includes tables that have been joined
        up to that point: the unqualified column is not ambiguous if no other sources up
        until that join contain a column with the same name.
        """
        args = self.scope.expression.args

        # Collect tables in order: FROM clause tables + joined tables up to current join
        from_name = args["from_"].alias_or_name
        available_sources = {from_name: self.get_source_columns(from_name)}

        for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
            available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)

        return available_sources

    def _get_unambiguous_columns(
        self, source_columns: dict[str, Sequence[str]]
    ) -> Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        # For BigQuery UNNEST_COLUMN_ONLY, build a mapping of original UNNEST aliases
        # from alias.columns[0] to their source names. This is used to resolve shadowing
        # where an UNNEST alias shadows a column name from another table.
        unnest_original_aliases: dict[str, str] = {}
        if self.dialect.UNNEST_COLUMN_ONLY:
            unnest_original_aliases = {
                alias_arg.columns[0].name: source_name
                for source_name, source in self.scope.sources.items()
                if (
                    isinstance(source.expression, exp.Unnest)
                    and (alias_arg := source.expression.args.get("alias"))
                    and alias_arg.columns
                )
            }

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                if column in unnest_original_aliases:
                    # UNNEST alias shadowing wins over ambiguity removal.
                    unambiguous_columns[column] = unnest_original_aliases[column]
                    continue

                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    def _get_unnest_column_type(self, column: exp.Column) -> exp.DataType | None:
        """
        Get the type of a column being unnested, tracing through CTEs/subqueries to find
        the base table.

        Args:
            column: The column expression being unnested.

        Returns:
            The DataType of the column, or None if not found.
        """
        scope = self.scope.parent
        assert scope

        # if column is qualified, use that table, otherwise disambiguate using the resolver
        if column.table:
            table_name = column.table
        else:
            # use the parent scope's resolver to disambiguate the column
            parent_resolver = Resolver(scope, self.schema, self._infer_schema)
            table_identifier = parent_resolver.get_table(column)
            if not table_identifier:
                return None
            table_name = table_identifier.name

        source = scope.sources.get(table_name)
        return self._get_column_type_from_scope(source, column) if source else None

    def _get_column_type_from_scope(
        self, source: Scope | exp.Table, column: exp.Column
    ) -> exp.DataType | None:
        """
        Get a column's type by tracing through scopes/tables to find the base table.

        Args:
            source: The source to search - can be a Scope (to iterate its sources) or a Table.
            column: The column to find the type for.

        Returns:
            The DataType of the column, or None if not found.
        """
        if isinstance(source, exp.Table):
            # base table - get the column type from schema
            col_type: exp.DataType | None = self.schema.get_column_type(source, column)
            if col_type and not col_type.is_type(exp.DType.UNKNOWN):
                return col_type
        elif isinstance(source, Scope):
            # iterate over all sources in the scope
            for source_name, nested_source in source.sources.items():
                col_type = self._get_column_type_from_scope(nested_source, column)
                if col_type and not col_type.is_type(exp.DType.UNKNOWN):
                    return col_type

        return None
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
    """Initialize the resolver for one scope against a schema."""
    self.scope = scope
    self.schema = schema
    # Fall back to a default Dialect when the schema carries none.
    self.dialect = schema.dialect or Dialect()
    # Lazily-populated caches; computed on first use and reused thereafter.
    self._source_columns: dict[str, Sequence[str]] | None = None
    self._unambiguous_columns: Mapping[str, str] | None = None
    self._all_columns: set[str] | None = None
    self._infer_schema = infer_schema
    # Memoizes get_source_columns, keyed by (source name, only_visible).
    self._get_source_columns_cache: dict[tuple[str, bool], Sequence[str]] = {}
def
get_table( self, column: str | sqlglot.expressions.core.Column) -> sqlglot.expressions.core.Identifier | None:
def get_table(self, column: str | exp.Column) -> exp.Identifier | None:
    """
    Get the table for a column name.

    Args:
        column: The column expression (or column name) to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    column_name = column if isinstance(column, str) else column.name

    table_name = self._get_table_name_from_sources(column_name)

    if not table_name and isinstance(column, exp.Column):
        # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
        # attempt to disambiguate the column based on other characteristics e.g if this
        # column is in a join condition, we may be able to disambiguate based on the
        # source order.
        if join_context := self._get_column_join_context(column):
            # The returned join _may_ be able to disambiguate the column; use the
            # source columns available at that join to get the table name.
            # Catch OptimizeError if column is still ambiguous and try to resolve with
            # schema inference below.
            try:
                table_name = self._get_table_name_from_sources(
                    column_name, self._get_available_source_columns(join_context)
                )
            except OptimizeError:
                pass

    if not table_name and self._infer_schema:
        # If exactly one source has no (or wildcard) columns, assume the column
        # belongs to it.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node: exp.Expr = self.scope.selected_sources[table_name][0]

    if isinstance(node, exp.Query):
        # Walk up to the node that actually carries the alias for this source.
        while node and node.alias != table_name and node.parent:
            node = node.parent

    node_alias = node.args.get("alias")
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column: The column expression (or column name) to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: set[str]
@property
def all_columns(self) -> set[str]:
    """All available columns of all sources in this scope"""
    # Computed lazily and memoized on first access.
    if self._all_columns is None:
        self._all_columns = {
            column for columns in self._get_all_source_columns().values() for column in columns
        }
    return self._all_columns
All available columns of all sources in this scope
96 def get_source_columns_from_set_op(self, expression: exp.Expr) -> list[str]: 97 if isinstance(expression, exp.Select): 98 return expression.named_selects 99 if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation): 100 # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting 101 return self.get_source_columns_from_set_op(expression.this) 102 if not isinstance(expression, exp.SetOperation): 103 raise OptimizeError(f"Unknown set operation: {expression}") 104 105 set_op = expression 106 107 # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME 108 on_column_list = set_op.args.get("on") 109 110 if on_column_list: 111 # The resulting columns are the columns in the ON clause: 112 # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...) 113 columns = [col.name for col in on_column_list] 114 elif set_op.side or set_op.kind: 115 side = set_op.side 116 kind = set_op.kind 117 118 # Visit the children UNIONs (if any) in a post-order traversal 119 left = self.get_source_columns_from_set_op(set_op.left) 120 right = self.get_source_columns_from_set_op(set_op.right) 121 122 # We use dict.fromkeys to deduplicate keys and maintain insertion order 123 if side == "LEFT": 124 columns = left 125 elif side == "FULL": 126 columns = list(dict.fromkeys(left + right)) 127 elif kind == "INNER": 128 columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys()) 129 else: 130 columns = set_op.named_selects 131 132 return columns
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
    def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source registered in `self.scope.sources`.
            only_visible: Forwarded to `Schema.column_names` when the source is a
                concrete table, restricting the result to visible columns.
        Returns:
            The column names exposed by the source. Results are memoized per
            `(name, only_visible)` in `self._get_source_columns_cache`.
        Raises:
            OptimizeError: If `name` is not a known source in this scope.
        """
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Concrete table: columns come from the schema.
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
                    unnest = source.expression

                    # if type is not annotated yet, try to get it from the schema
                    if not unnest.type or unnest.type.is_type(exp.DType.UNKNOWN):
                        unnest_expr = seq_get(unnest.expressions, 0)
                        if isinstance(unnest_expr, exp.Column) and self.scope.parent:
                            col_type = self._get_unnest_column_type(unnest_expr)
                            # extract element type if it's an ARRAY
                            if col_type and col_type.is_type(exp.DType.ARRAY):
                                element_types = col_type.expressions
                                if element_types:
                                    unnest.type = element_types[0].copy()
                            else:
                                if col_type:
                                    unnest.type = col_type.copy()
                    # check if the result type is a STRUCT - extract struct field names
                    if unnest.is_type(exp.DType.STRUCT):
                        for k in unnest.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                # Set operations (UNION etc.) need dedicated handling for BY NAME modifiers.
                columns = self.get_source_columns_from_set_op(source.expression)
            else:
                selectable = source.expression.assert_is(exp.Selectable)
                select = seq_get(selectable.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    # TRANSFORM without an explicit schema yields the default (key, value) pair.
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = selectable.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expr):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                # NOTE: `name` in the comprehension shadows the parameter (scoped to the
                # comprehension only); zip_longest pads the shorter side with None, so
                # unmatched columns keep their original names and surplus aliases stand alone.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]
Resolve the source columns for a given source `name`.