sqlglot.optimizer.resolver
"""Column-resolution helpers for the sqlglot optimizer."""

from __future__ import annotations

import itertools
import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.scope import Scope

if t.TYPE_CHECKING:
    from sqlglot.schema import Schema


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self.dialect = schema.dialect or Dialect()
        # Lazily-populated caches, filled on first use by the helpers below.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Keyed by (source name, only_visible) -> resolved column names.
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column: str | exp.Column) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column: The column expression (or column name) to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        column_name = column if isinstance(column, str) else column.name

        table_name = self._get_table_name_from_sources(column_name)

        if not table_name and isinstance(column, exp.Column):
            # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
            # attempt to disambiguate the column based on other characteristics e.g if this column
            # is in a join condition, we may be able to disambiguate based on the source order.
            if join_context := self._get_column_join_context(column):
                # In this case, the return value will be the join that _may_ be able to disambiguate
                # the column and we can use the source columns available at that join to get the
                # table name; catch OptimizeError if the column is still ambiguous and try to
                # resolve with schema inference below.
                try:
                    table_name = self._get_table_name_from_sources(
                        column_name, self._get_available_source_columns(join_context)
                    )
                except OptimizeError:
                    pass

        if not table_name and self._infer_schema:
            # If exactly one source has no (or wildcard) column info, assume the column comes from it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the node that actually carries the alias `table_name`.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
        """Compute the output column names of a set operation (e.g. UNION) expression."""
        if isinstance(expression, exp.Select):
            return expression.named_selects
        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
            return self.get_source_columns_from_set_op(expression.this)
        if not isinstance(expression, exp.SetOperation):
            raise OptimizeError(f"Unknown set operation: {expression}")

        set_op = expression

        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
        on_column_list = set_op.args.get("on")

        if on_column_list:
            # The resulting columns are the columns in the ON clause:
            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
            columns = [col.name for col in on_column_list]
        elif set_op.side or set_op.kind:
            side = set_op.side
            kind = set_op.kind

            # Visit the children UNIONs (if any) in a post-order traversal
            left = self.get_source_columns_from_set_op(set_op.left)
            right = self.get_source_columns_from_set_op(set_op.right)

            # We use dict.fromkeys to deduplicate keys and maintain insertion order
            if side == "LEFT":
                columns = left
            elif side == "FULL":
                columns = list(dict.fromkeys(left + right))
            elif kind == "INNER":
                # NOTE(review): a keys-view intersection (`&`) produces a set, so unlike the
                # branches above, the resulting order is not guaranteed here — confirm whether
                # ordering matters to callers.
                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
        else:
            columns = set_op.named_selects

        return columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
                    unnest = source.expression

                    # if type is not annotated yet, try to get it from the schema
                    if not unnest.type or unnest.type.is_type(exp.DataType.Type.UNKNOWN):
                        unnest_expr = seq_get(unnest.expressions, 0)
                        if isinstance(unnest_expr, exp.Column) and self.scope.parent:
                            col_type = self._get_unnest_column_type(unnest_expr)
                            # extract element type if it's an ARRAY
                            if col_type and col_type.is_type(exp.DataType.Type.ARRAY):
                                element_types = col_type.expressions
                                if element_types:
                                    unnest.type = element_types[0].copy()
                            else:
                                if col_type:
                                    unnest.type = col_type.copy()

                    # check if the result type is a STRUCT - extract struct field names
                    if unnest.is_type(exp.DataType.Type.STRUCT):
                        for k in unnest.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                columns = self.get_source_columns_from_set_op(source.expression)

            else:
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Map every selected/lateral source name to its resolved columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_table_name_from_sources(
        self, column_name: str, source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
    ) -> t.Optional[str]:
        """Return the single source that unambiguously provides `column_name`, if any."""
        if not source_columns:
            # If not supplied, get all sources to calculate unambiguous columns
            if self._unambiguous_columns is None:
                self._unambiguous_columns = self._get_unambiguous_columns(
                    self._get_all_source_columns()
                )

            unambiguous_columns = self._unambiguous_columns
        else:
            unambiguous_columns = self._get_unambiguous_columns(source_columns)

        return unambiguous_columns.get(column_name)

    def _get_column_join_context(self, column: exp.Column) -> t.Optional[exp.Join]:
        """
        Check if a column participating in a join can be qualified based on the source order.
        """
        args = self.scope.expression.args
        joins = args.get("joins")

        if not joins or args.get("laterals") or args.get("pivots"):
            # Feature gap: We currently don't try to disambiguate columns if other sources
            # (e.g laterals, pivots) exist alongside joins
            return None

        join_ancestor = column.find_ancestor(exp.Join, exp.Select)

        if (
            isinstance(join_ancestor, exp.Join)
            and join_ancestor.alias_or_name in self.scope.selected_sources
        ):
            # Ensure that the found ancestor is a join that contains an actual source,
            # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
            return join_ancestor

        return None

    def _get_available_source_columns(
        self, join_ancestor: exp.Join
    ) -> t.Dict[str, t.Sequence[str]]:
        """
        Get the source columns that are available at the point where a column is referenced.

        For columns in JOIN conditions, this only includes tables that have been joined
        up to that point. Example:

        ```
        SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ...
                                                                   ^
                                                                   |
        The unqualified column `c` is not ambiguous if no other sources up until that
        join i.e t_1, ..., t_n, contain a column named `c`.
        ```
        """
        args = self.scope.expression.args

        # Collect tables in order: FROM clause tables + joined tables up to current join
        from_name = args["from_"].alias_or_name
        available_sources = {from_name: self.get_source_columns(from_name)}

        for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
            available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)

        return available_sources

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        # For BigQuery UNNEST_COLUMN_ONLY, build a mapping of original UNNEST aliases
        # from alias.columns[0] to their source names. This is used to resolve shadowing
        # where an UNNEST alias shadows a column name from another table.
        unnest_original_aliases: t.Dict[str, str] = {}
        if self.dialect.UNNEST_COLUMN_ONLY:
            unnest_original_aliases = {
                alias_arg.columns[0].name: source_name
                for source_name, source in self.scope.sources.items()
                if (
                    isinstance(source.expression, exp.Unnest)
                    and (alias_arg := source.expression.args.get("alias"))
                    and alias_arg.columns
                )
            }

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                if column in unnest_original_aliases:
                    # The UNNEST alias takes precedence over the shadowed column.
                    unambiguous_columns[column] = unnest_original_aliases[column]
                    continue

                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    def _get_unnest_column_type(self, column: exp.Column) -> t.Optional[exp.DataType]:
        """
        Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.

        Args:
            column: The column expression being unnested.

        Returns:
            The DataType of the column, or None if not found.
        """
        # NOTE: get_source_columns only calls this when `self.scope.parent` exists.
        scope = self.scope.parent

        # if column is qualified, use that table, otherwise disambiguate using the resolver
        if column.table:
            table_name = column.table
        else:
            # use the parent scope's resolver to disambiguate the column
            parent_resolver = Resolver(scope, self.schema, self._infer_schema)
            table_identifier = parent_resolver.get_table(column)
            if not table_identifier:
                return None
            table_name = table_identifier.name

        source = scope.sources.get(table_name)
        return self._get_column_type_from_scope(source, column) if source else None

    def _get_column_type_from_scope(
        self, source: t.Union[Scope, exp.Table], column: exp.Column
    ) -> t.Optional[exp.DataType]:
        """
        Get a column's type by tracing through scopes/tables to find the base table.

        Args:
            source: The source to search - can be a Scope (to iterate its sources) or a Table.
            column: The column to find the type for.

        Returns:
            The DataType of the column, or None if not found.
        """
        if isinstance(source, exp.Table):
            # base table - get the column type from schema
            col_type: t.Optional[exp.DataType] = self.schema.get_column_type(source, column)
            if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
                return col_type
        elif isinstance(source, Scope):
            # iterate over all sources in the scope
            for source_name, nested_source in source.sources.items():
                col_type = self._get_column_type_from_scope(nested_source, column)
                if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
                    return col_type

        return None
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self.dialect = schema.dialect or Dialect()
        # Lazily-populated caches, filled on first use by the helpers below.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Keyed by (source name, only_visible) -> resolved column names.
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column: str | exp.Column) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column: The column expression (or column name) to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        column_name = column if isinstance(column, str) else column.name

        table_name = self._get_table_name_from_sources(column_name)

        if not table_name and isinstance(column, exp.Column):
            # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
            # attempt to disambiguate the column based on other characteristics e.g if this column
            # is in a join condition, we may be able to disambiguate based on the source order.
            if join_context := self._get_column_join_context(column):
                # In this case, the return value will be the join that _may_ be able to disambiguate
                # the column and we can use the source columns available at that join to get the
                # table name; catch OptimizeError if the column is still ambiguous and try to
                # resolve with schema inference below.
                try:
                    table_name = self._get_table_name_from_sources(
                        column_name, self._get_available_source_columns(join_context)
                    )
                except OptimizeError:
                    pass

        if not table_name and self._infer_schema:
            # If exactly one source has no (or wildcard) column info, assume the column comes from it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the node that actually carries the alias `table_name`.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
        """Compute the output column names of a set operation (e.g. UNION) expression."""
        if isinstance(expression, exp.Select):
            return expression.named_selects
        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
            return self.get_source_columns_from_set_op(expression.this)
        if not isinstance(expression, exp.SetOperation):
            raise OptimizeError(f"Unknown set operation: {expression}")

        set_op = expression

        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
        on_column_list = set_op.args.get("on")

        if on_column_list:
            # The resulting columns are the columns in the ON clause:
            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
            columns = [col.name for col in on_column_list]
        elif set_op.side or set_op.kind:
            side = set_op.side
            kind = set_op.kind

            # Visit the children UNIONs (if any) in a post-order traversal
            left = self.get_source_columns_from_set_op(set_op.left)
            right = self.get_source_columns_from_set_op(set_op.right)

            # We use dict.fromkeys to deduplicate keys and maintain insertion order
            if side == "LEFT":
                columns = left
            elif side == "FULL":
                columns = list(dict.fromkeys(left + right))
            elif kind == "INNER":
                # NOTE(review): a keys-view intersection (`&`) produces a set, so unlike the
                # branches above, the resulting order is not guaranteed here — confirm whether
                # ordering matters to callers.
                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
        else:
            columns = set_op.named_selects

        return columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
                    unnest = source.expression

                    # if type is not annotated yet, try to get it from the schema
                    if not unnest.type or unnest.type.is_type(exp.DataType.Type.UNKNOWN):
                        unnest_expr = seq_get(unnest.expressions, 0)
                        if isinstance(unnest_expr, exp.Column) and self.scope.parent:
                            col_type = self._get_unnest_column_type(unnest_expr)
                            # extract element type if it's an ARRAY
                            if col_type and col_type.is_type(exp.DataType.Type.ARRAY):
                                element_types = col_type.expressions
                                if element_types:
                                    unnest.type = element_types[0].copy()
                            else:
                                if col_type:
                                    unnest.type = col_type.copy()

                    # check if the result type is a STRUCT - extract struct field names
                    if unnest.is_type(exp.DataType.Type.STRUCT):
                        for k in unnest.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                columns = self.get_source_columns_from_set_op(source.expression)

            else:
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Map every selected/lateral source name to its resolved columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_table_name_from_sources(
        self, column_name: str, source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
    ) -> t.Optional[str]:
        """Return the single source that unambiguously provides `column_name`, if any."""
        if not source_columns:
            # If not supplied, get all sources to calculate unambiguous columns
            if self._unambiguous_columns is None:
                self._unambiguous_columns = self._get_unambiguous_columns(
                    self._get_all_source_columns()
                )

            unambiguous_columns = self._unambiguous_columns
        else:
            unambiguous_columns = self._get_unambiguous_columns(source_columns)

        return unambiguous_columns.get(column_name)

    def _get_column_join_context(self, column: exp.Column) -> t.Optional[exp.Join]:
        """
        Check if a column participating in a join can be qualified based on the source order.
        """
        args = self.scope.expression.args
        joins = args.get("joins")

        if not joins or args.get("laterals") or args.get("pivots"):
            # Feature gap: We currently don't try to disambiguate columns if other sources
            # (e.g laterals, pivots) exist alongside joins
            return None

        join_ancestor = column.find_ancestor(exp.Join, exp.Select)

        if (
            isinstance(join_ancestor, exp.Join)
            and join_ancestor.alias_or_name in self.scope.selected_sources
        ):
            # Ensure that the found ancestor is a join that contains an actual source,
            # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
            return join_ancestor

        return None

    def _get_available_source_columns(
        self, join_ancestor: exp.Join
    ) -> t.Dict[str, t.Sequence[str]]:
        """
        Get the source columns that are available at the point where a column is referenced.

        For columns in JOIN conditions, this only includes tables that have been joined
        up to that point. Example:

        ```
        SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ...
                                                                   ^
                                                                   |
        The unqualified column `c` is not ambiguous if no other sources up until that
        join i.e t_1, ..., t_n, contain a column named `c`.
        ```
        """
        args = self.scope.expression.args

        # Collect tables in order: FROM clause tables + joined tables up to current join
        from_name = args["from_"].alias_or_name
        available_sources = {from_name: self.get_source_columns(from_name)}

        for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
            available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)

        return available_sources

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        # For BigQuery UNNEST_COLUMN_ONLY, build a mapping of original UNNEST aliases
        # from alias.columns[0] to their source names. This is used to resolve shadowing
        # where an UNNEST alias shadows a column name from another table.
        unnest_original_aliases: t.Dict[str, str] = {}
        if self.dialect.UNNEST_COLUMN_ONLY:
            unnest_original_aliases = {
                alias_arg.columns[0].name: source_name
                for source_name, source in self.scope.sources.items()
                if (
                    isinstance(source.expression, exp.Unnest)
                    and (alias_arg := source.expression.args.get("alias"))
                    and alias_arg.columns
                )
            }

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                if column in unnest_original_aliases:
                    # The UNNEST alias takes precedence over the shadowed column.
                    unambiguous_columns[column] = unnest_original_aliases[column]
                    continue

                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    def _get_unnest_column_type(self, column: exp.Column) -> t.Optional[exp.DataType]:
        """
        Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.

        Args:
            column: The column expression being unnested.

        Returns:
            The DataType of the column, or None if not found.
        """
        # NOTE: get_source_columns only calls this when `self.scope.parent` exists.
        scope = self.scope.parent

        # if column is qualified, use that table, otherwise disambiguate using the resolver
        if column.table:
            table_name = column.table
        else:
            # use the parent scope's resolver to disambiguate the column
            parent_resolver = Resolver(scope, self.schema, self._infer_schema)
            table_identifier = parent_resolver.get_table(column)
            if not table_identifier:
                return None
            table_name = table_identifier.name

        source = scope.sources.get(table_name)
        return self._get_column_type_from_scope(source, column) if source else None

    def _get_column_type_from_scope(
        self, source: t.Union[Scope, exp.Table], column: exp.Column
    ) -> t.Optional[exp.DataType]:
        """
        Get a column's type by tracing through scopes/tables to find the base table.

        Args:
            source: The source to search - can be a Scope (to iterate its sources) or a Table.
            column: The column to find the type for.

        Returns:
            The DataType of the column, or None if not found.
        """
        if isinstance(source, exp.Table):
            # base table - get the column type from schema
            col_type: t.Optional[exp.DataType] = self.schema.get_column_type(source, column)
            if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
                return col_type
        elif isinstance(source, Scope):
            # iterate over all sources in the scope
            for source_name, nested_source in source.sources.items():
                col_type = self._get_column_type_from_scope(nested_source, column)
                if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
                    return col_type

        return None
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
24 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 25 self.scope = scope 26 self.schema = schema 27 self.dialect = schema.dialect or Dialect() 28 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 29 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 30 self._all_columns: t.Optional[t.Set[str]] = None 31 self._infer_schema = infer_schema 32 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
def
get_table( self, column: str | sqlglot.expressions.Column) -> Optional[sqlglot.expressions.Identifier]:
34 def get_table(self, column: str | exp.Column) -> t.Optional[exp.Identifier]: 35 """ 36 Get the table for a column name. 37 38 Args: 39 column: The column expression (or column name) to find the table for. 40 Returns: 41 The table name if it can be found/inferred. 42 """ 43 column_name = column if isinstance(column, str) else column.name 44 45 table_name = self._get_table_name_from_sources(column_name) 46 47 if not table_name and isinstance(column, exp.Column): 48 # Fall-back case: If we couldn't find the `table_name` from ALL of the sources, 49 # attempt to disambiguate the column based on other characteristics e.g if this column is in a join condition, 50 # we may be able to disambiguate based on the source order. 51 if join_context := self._get_column_join_context(column): 52 # In this case, the return value will be the join that _may_ be able to disambiguate the column 53 # and we can use the source columns available at that join to get the table name 54 # catch OptimizeError if column is still ambiguous and try to resolve with schema inference below 55 try: 56 table_name = self._get_table_name_from_sources( 57 column_name, self._get_available_source_columns(join_context) 58 ) 59 except OptimizeError: 60 pass 61 62 if not table_name and self._infer_schema: 63 sources_without_schema = tuple( 64 source 65 for source, columns in self._get_all_source_columns().items() 66 if not columns or "*" in columns 67 ) 68 if len(sources_without_schema) == 1: 69 table_name = sources_without_schema[0] 70 71 if table_name not in self.scope.selected_sources: 72 return exp.to_identifier(table_name) 73 74 node, _ = self.scope.selected_sources.get(table_name) 75 76 if isinstance(node, exp.Query): 77 while node and node.alias != table_name: 78 node = node.parent 79 80 node_alias = node.args.get("alias") 81 if node_alias: 82 return exp.to_identifier(node_alias.this) 83 84 return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column: The column expression (or column name) to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
86 @property 87 def all_columns(self) -> t.Set[str]: 88 """All available columns of all sources in this scope""" 89 if self._all_columns is None: 90 self._all_columns = { 91 column for columns in self._get_all_source_columns().values() for column in columns 92 } 93 return self._all_columns
All available columns of all sources in this scope
95 def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]: 96 if isinstance(expression, exp.Select): 97 return expression.named_selects 98 if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation): 99 # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting 100 return self.get_source_columns_from_set_op(expression.this) 101 if not isinstance(expression, exp.SetOperation): 102 raise OptimizeError(f"Unknown set operation: {expression}") 103 104 set_op = expression 105 106 # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME 107 on_column_list = set_op.args.get("on") 108 109 if on_column_list: 110 # The resulting columns are the columns in the ON clause: 111 # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...) 112 columns = [col.name for col in on_column_list] 113 elif set_op.side or set_op.kind: 114 side = set_op.side 115 kind = set_op.kind 116 117 # Visit the children UNIONs (if any) in a post-order traversal 118 left = self.get_source_columns_from_set_op(set_op.left) 119 right = self.get_source_columns_from_set_op(set_op.right) 120 121 # We use dict.fromkeys to deduplicate keys and maintain insertion order 122 if side == "LEFT": 123 columns = left 124 elif side == "FULL": 125 columns = list(dict.fromkeys(left + right)) 126 elif kind == "INNER": 127 columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys()) 128 else: 129 columns = set_op.named_selects 130 131 return columns
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        # Memoized per (name, only_visible): callers may request the same source's
        # columns repeatedly during qualification.
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Physical table: column names come straight from the schema.
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
                    unnest = source.expression

                    # if type is not annotated yet, try to get it from the schema
                    if not unnest.type or unnest.type.is_type(exp.DataType.Type.UNKNOWN):
                        unnest_expr = seq_get(unnest.expressions, 0)
                        if isinstance(unnest_expr, exp.Column) and self.scope.parent:
                            col_type = self._get_unnest_column_type(unnest_expr)
                            # extract element type if it's an ARRAY
                            if col_type and col_type.is_type(exp.DataType.Type.ARRAY):
                                element_types = col_type.expressions
                                if element_types:
                                    # Annotate the unnest node in place with the
                                    # array's element type.
                                    unnest.type = element_types[0].copy()
                            else:
                                if col_type:
                                    unnest.type = col_type.copy()
                    # check if the result type is a STRUCT - extract struct field names
                    if unnest.is_type(exp.DataType.Type.STRUCT):
                        for k in unnest.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                columns = self.get_source_columns_from_set_op(source.expression)

            else:
                # Generic subquery source: inspect its first projection.
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]
Resolve the source columns for a given source name.