sqlglot.optimizer.qualify_columns
Module source:

```python
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    name_columns = [
        field.this
        for field in unpivot.fields
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    ]
    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))

    return itertools.chain(name_columns, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
        => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select) or dialect == "oracle":
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g.:
            # SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            # SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result in FUNC(FUNC(col))
            # This is not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If HAVING x is expanded to max(x.b), bigquery treats x as the new projection x instead of the table
                if is_having and dialect == "bigquery":
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column, dialect)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)
            elif (
                resolver.schema.dialect == "bigquery"
                and len(column.parts) == 1
                and column_name in scope.selected_sources
            ):
                # BigQuery allows tables to be referenced as columns, treating them as structs
                scope.replace(column, exp.TableColumn(this=column.this))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars_bigquery(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct, e.g. t.c.f1.f2.*, find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_struct_stars_risingwave(expression: exp.Dot) -> t.List[exp.Alias]:
    """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column"""

    # it is not a (<sub_exp>).* pattern, which means we can't expand
    if not isinstance(expression.this, exp.Paren):
        return []

    # find the column definition to get its data type
    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    parent = dot_column.parent
    starting_struct = dot_column.type

    # walk up the AST and down into the struct definition in sync
    while parent is not None:
        if isinstance(parent, exp.Paren):
            parent = parent.parent
            continue

        # if the parent is not a dot, then something is wrong
        if not isinstance(parent, exp.Dot):
            return []

        # if the rhs of the dot is a star, we are done
        rhs = parent.right
        if isinstance(rhs, exp.Star):
            break

        # if it is not an identifier, then something is wrong
        if not isinstance(rhs, exp.Identifier):
            return []

        # check if the current rhs identifier is in the struct
        matched = False
        for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
            if struct_field_def.name == rhs.name:
                matched = True
                starting_struct = struct_field_def.kind  # update struct
                break

        if not matched:
            return []

        parent = parent.parent

    # build new aliases to expand the star
    new_selections = []

    # fetch the outermost parentheses for the new aliases
    outer_paren = expression.this

    for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
        new_identifier = struct_field_def.this.copy()
        new_dot = exp.Dot.build([outer_paren.copy(), new_identifier])
        new_alias = alias(new_dot, new_identifier, copy=False)
        new_selections.append(new_alias)

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )

        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    is_risingwave = dialect == "risingwave"

    if (is_bigquery or is_risingwave) and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars_bigquery(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue
            elif is_risingwave:
                struct_fields = _expand_struct_stars_risingwave(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None or isinstance(selection, exp.QueryTransform):
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                set_op = source.expression

                # BigQuery-specific set operation modifiers, e.g. INNER UNION ALL BY NAME
                on_column_list = set_op.args.get("on")

                if on_column_list:
                    # The resulting columns are the columns in the ON clause:
                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
                    columns = [col.name for col in on_column_list]
                elif set_op.side or set_op.kind:
                    side = set_op.side
                    kind = set_op.kind

                    left = set_op.left.named_selects
                    right = set_op.right.named_selects

                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
                    if side == "LEFT":
                        columns = left
                    elif side == "FULL":
                        columns = list(dict.fromkeys(left + right))
                    elif kind == "INNER":
                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
                else:
                    columns = set_op.named_selects
            else:
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
```
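The `USING` handling in `_expand_using` above is easiest to see end to end: it rewrites `USING` into an explicit `ON` condition and turns the shared column into a `COALESCE` projection. A minimal sketch (the schema is made up for illustration, and the exact output text may vary across sqlglot versions):

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

# Hypothetical toy schema: both tables expose a column named `c`
schema = {"x": {"c": "INT"}, "y": {"c": "INT"}}

expression = sqlglot.parse_one("SELECT c FROM x JOIN y USING (c)")

# USING (c) becomes ON x.c = y.c, and the shared projection becomes
# COALESCE(x.c, y.c) so either side of the join can supply the value
print(qualify_columns(expression, schema).sql())
# Expected along the lines of:
# SELECT COALESCE(x.c, y.c) AS c FROM x JOIN y ON x.c = y.c
```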
```python
def qualify_columns(
    expression: sqlglot.expressions.Expression,
    schema: Union[Dict, sqlglot.schema.Schema],
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None,
) -> sqlglot.expressions.Expression:
```
Rewrite sqlglot AST to have fully qualified columns.
Example:
```python
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
```
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
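The doctest above shows a plain column; star expansion is the other common case. A minimal sketch with the same toy schema (the exact output may vary slightly by version):

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

schema = {"tbl": {"col": "INT"}}

# With expand_stars left at its default, the star is replaced by the
# schema's column list, and each column is qualified and aliased
print(qualify_columns(sqlglot.parse_one("SELECT * FROM tbl"), schema).sql())
# Expected: SELECT tbl.col AS col FROM tbl
```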
```python
def validate_qualify_columns(expression: ~E) -> ~E:
```
Raise an `OptimizeError` if any columns aren't qualified.
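A small sketch of the failure mode this guards against, assuming a query whose bare column cannot be attributed to a single source:

```python
import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import validate_qualify_columns

# `a` carries no table prefix and there are two candidate sources,
# so qualification cannot resolve it and validation raises
try:
    validate_qualify_columns(sqlglot.parse_one("SELECT a FROM x, y"))
except OptimizeError as e:
    print(e)  # the message lists the ambiguous columns, e.g. [a]
```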
```python
def qualify_outputs(
    scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression,
) -> None:
```
Ensure all output columns are aliased
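A minimal sketch of the behavior (the function mutates in place and returns `None`; the `_col_<i>` naming comes from the source above):

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_outputs

expression = sqlglot.parse_one("SELECT x.a, x.b + 1 FROM x")
qualify_outputs(expression)

# Columns are aliased to their own name; anonymous expressions
# fall back to a positional _col_<i> alias
print(expression.sql())
# Expected: SELECT x.a AS a, x.b + 1 AS _col_1 FROM x
```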
```python
def quote_identifiers(
    expression: ~E,
    dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None,
    identify: bool = True,
) -> ~E:
```
Makes sure all identifiers that need to be quoted are quoted.
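A quick sketch, using DuckDB as an arbitrary example dialect; with `identify=True` (the default) every identifier gets quoted:

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import quote_identifiers

expression = sqlglot.parse_one("SELECT tbl.col AS col FROM tbl")
print(quote_identifiers(expression, dialect="duckdb").sql(dialect="duckdb"))
# Expected: SELECT "tbl"."col" AS "col" FROM "tbl"
```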
```python
def pushdown_cte_alias_columns(
    expression: sqlglot.expressions.Expression,
) -> sqlglot.expressions.Expression:
```
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
```python
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
```
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
```python
class Resolver:
```
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
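A sketch of how the optimizer wires a `Resolver` up, using the public helpers from the source above (the toy schema is made up for illustration):

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"x": {"a": "INT"}, "y": {"b": "INT"}})
scope = build_scope(sqlglot.parse_one("SELECT a, b FROM x JOIN y ON x.a = y.b"))

resolver = Resolver(scope, schema)
print(resolver.get_table("a").name)      # x (the only source exposing `a`)
print(resolver.get_source_columns("y"))  # ['b']
print(sorted(resolver.all_columns))      # ['a', 'b']
```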
```python
Resolver(
    scope: sqlglot.optimizer.scope.Scope,
    schema: sqlglot.schema.Schema,
    infer_schema: bool = True,
)
```
```python
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
```
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
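When nothing in the schema matches, the `infer_schema` fallback shown in the source kicks in: if exactly one source has no known columns, an unresolved column is attributed to it. A hedged sketch:

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

# `t` is absent from the (empty) schema, so its column list is empty
scope = build_scope(sqlglot.parse_one("SELECT a FROM t"))
resolver = Resolver(scope, ensure_schema({}), infer_schema=True)

# `a` is unknown everywhere, but `t` is the single schemaless source
print(resolver.get_table("a").name)  # t
```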
all_columns: Set[str]
All available columns of all sources in this scope
```python
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
```
Resolve the source columns for a given source `name`.
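A small sketch: for a derived table, the columns come from the subquery's projections rather than the schema:

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

scope = build_scope(sqlglot.parse_one("SELECT * FROM (SELECT 1 AS c, 2 AS d) AS t"))
resolver = Resolver(scope, ensure_schema({}))

# The derived table's named selects are its source columns
print(resolver.get_source_columns("t"))  # ['c', 'd']
```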