sqlglot.optimizer.qualify_columns
```python
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    name_columns = [
        field.this
        for field in unpivot.fields
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    ]
    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))

    return itertools.chain(name_columns, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.

    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select) or dialect == "oracle":
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result in FUNC(FUNC(col))
            # This is not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If HAVING x is expanded to max(x.b), bigquery treats x as the new projection x instead of the table
                if is_having and dialect == "bigquery":
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)
            elif (
                resolver.schema.dialect == "bigquery"
                and len(column.parts) == 1
                and column_name in scope.selected_sources
            ):
                # BigQuery allows tables to be referenced as columns, treating them as structs
                scope.replace(column, exp.TableColumn(this=column.this))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None or isinstance(selection, exp.QueryTransform):
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                set_op = source.expression

                # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
                on_column_list = set_op.args.get("on")

                if on_column_list:
                    # The resulting columns are the columns in the ON clause:
                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
                    columns = [col.name for col in on_column_list]
                elif set_op.side or set_op.kind:
                    side = set_op.side
                    kind = set_op.kind

                    left = set_op.left.named_selects
                    right = set_op.right.named_selects

                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
                    if side == "LEFT":
                        columns = left
                    elif side == "FULL":
                        columns = list(dict.fromkeys(left + right))
                    elif kind == "INNER":
                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
                else:
                    columns = set_op.named_selects
            else:
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
```
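To make the `_expand_using` rewrite above concrete, here is a minimal sketch. The schema and query are illustrative, and the printed SQL in the comment reflects how the rewrite is expected to look rather than a pinned doctest:

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

schema = {"t1": {"id": "INT", "a": "INT"}, "t2": {"id": "INT", "b": "INT"}}
expression = sqlglot.parse_one("SELECT id FROM t1 JOIN t2 USING (id)")

# USING is rewritten to an explicit ON condition, and the shared column
# becomes a COALESCE over its source tables so the projection stays valid.
print(qualify_columns(expression, schema).sql())
# SELECT COALESCE(t1.id, t2.id) AS id FROM t1 JOIN t2 ON t1.id = t2.id
```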
```python
def qualify_columns(
    expression: sqlglot.expressions.Expression,
    schema: Union[Dict, sqlglot.schema.Schema],
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None,
) -> sqlglot.expressions.Expression:
```
Rewrite sqlglot AST to have fully qualified columns.
Example:
```python
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
```
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
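Beyond the doctest above, star expansion is the piece most optimizer rules depend on. A hedged sketch on a two-table join (schema and query are illustrative; the commented output shows the expected shape of the rewrite):

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

schema = {"t1": {"a": "INT", "b": "INT"}, "t2": {"b": "INT", "c": "INT"}}

# t1.* is expanded into explicit, qualified projections, and the bare `c`
# is resolved to the only source that defines it.
expression = sqlglot.parse_one("SELECT t1.*, c FROM t1 JOIN t2 ON t1.b = t2.b")
print(qualify_columns(expression, schema).sql())
# SELECT t1.a AS a, t1.b AS b, t2.c AS c FROM t1 JOIN t2 ON t1.b = t2.b
```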
```python
def validate_qualify_columns(expression: ~E) -> ~E:
```
Raise an `OptimizeError` if any columns aren't qualified.
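A sketch of the failure mode this guards against, using an illustrative schema where `b` exists in both sources (the exact error message shape may vary):

```python
import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

schema = {"t1": {"b": "INT"}, "t2": {"b": "INT"}}
expression = sqlglot.parse_one("SELECT b FROM t1 JOIN t2 ON t1.b = t2.b")

try:
    # qualify_columns leaves the ambiguous `b` unqualified; validation rejects it
    validate_qualify_columns(qualify_columns(expression, schema))
except OptimizeError as e:
    print(e)  # e.g. "Ambiguous columns: [...]"
```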
```python
def qualify_outputs(
    scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression,
) -> None:
```
Ensure all output columns are aliased
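A small sketch of what this does to unaliased projections; the `_col_{i}` fallback names come from the source above, and the commented output is illustrative:

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_outputs

expression = sqlglot.parse_one("SELECT x.a, a + 1 FROM x")
# Named projections are aliased to their own name; anonymous expressions
# fall back to positional _col_{i} names.
qualify_outputs(expression)
print(expression.sql())
# SELECT x.a AS a, a + 1 AS _col_1 FROM x
```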
```python
def quote_identifiers(
    expression: ~E,
    dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None,
    identify: bool = True,
) -> ~E:
```
Makes sure all identifiers that need to be quoted are quoted.
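A minimal usage sketch; the dialect choice is illustrative and the quoting style follows whatever the chosen dialect generates:

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import quote_identifiers

expression = sqlglot.parse_one("SELECT a FROM tbl")
# With identify=True (the default), all identifiers are quoted
print(quote_identifiers(expression, dialect="duckdb").sql(dialect="duckdb"))
# SELECT "a" FROM "tbl"
```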
```python
def pushdown_cte_alias_columns(
    expression: sqlglot.expressions.Expression,
) -> sqlglot.expressions.Expression:
```
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
```python
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
```
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
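Since this rewrite exists so that later qualification can see the CTE's column names, it is typically run before `qualify_columns`. A hedged sketch of that ordering (the empty schema is illustrative):

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import pushdown_cte_alias_columns, qualify_columns

sql = "WITH y (c) AS (SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0) SELECT c FROM y"
expression = pushdown_cte_alias_columns(sqlglot.parse_one(sql))

# With `c` now present as a projection alias inside the CTE, column
# qualification can resolve the outer `SELECT c FROM y` against it.
print(qualify_columns(expression, schema={}).sql())
```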
```python
class Resolver:
```
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
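`Resolver` is internal machinery, but a hedged sketch of how it is wired up inside a pass can clarify the lazy resolution (the schema, query, and printed values are illustrative):

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t1": {"a": "INT"}, "t2": {"a": "INT", "b": "INT"}})
scope = build_scope(sqlglot.parse_one("SELECT a, b FROM t1 JOIN t2 ON t1.a = t2.a"))

resolver = Resolver(scope, schema)
print(resolver.get_table("b"))      # t2   -- unambiguous, resolved to its only source
print(resolver.get_table("a"))      # None -- ambiguous: both t1 and t2 define `a`
print(sorted(resolver.all_columns)) # ['a', 'b']
```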
```python
Resolver(
    scope: sqlglot.optimizer.scope.Scope,
    schema: sqlglot.schema.Schema,
    infer_schema: bool = True,
)
```
```python
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
```
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
```python
all_columns: Set[str]
```
All available columns of all sources in this scope
```python
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
```
Resolve the source columns for a given source `name`.
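A sketch of the column-alias shadowing noted in the source above, where a derived table's alias list replaces the inner column names (the query is illustrative):

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t": {"a": "INT", "b": "INT"}})
scope = build_scope(
    sqlglot.parse_one("SELECT x, y FROM (SELECT a, b FROM t) AS sub(x, y)")
)

resolver = Resolver(scope, schema)
# The alias list sub(x, y) shadows the inner names a and b
print(resolver.get_source_columns("sub"))  # ['x', 'y']
```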