sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    name_columns = [
        field.this
        for field in unpivot.fields
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    ]
    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))

    return itertools.chain(name_columns, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.

    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
        => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select) or dialect == "oracle":
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}
    is_bigquery = dialect == "bigquery"

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g.:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result in FUNC(FUNC(col))
            # This is not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                #   SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x" reference with the projection MAX(x.b),
                # i.e. HAVING MAX(MAX(x.b).b), resulting in the error: "Aggregations of aggregations are not allowed"
                if is_having and is_bigquery:
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )
            elif is_bigquery and (is_group_by or is_having):
                column_table = table.name if table else column.table
                if column_table in projections:
                    # BigQuery's GROUP BY and HAVING clauses get confused if the column name
                    # matches a source name and a projection. For instance:
                    #   SELECT id, ARRAY_AGG(col) AS custom_fields FROM custom_fields GROUP BY id HAVING id >= 1
                    # We should not qualify "id" with "custom_fields" in either clause, since the aggregation shadows the actual table
                    # and we'd get the error: "Column custom_fields contains an aggregation function, which is not allowed in GROUP BY clause"
                    column.replace(exp.to_identifier(column.name))
                    return

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if (isinstance(alias_expr, exp.Literal) or alias_expr.is_number) and (
                    literal_index or resolve_table
                ):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column, dialect)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.is_number
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)
            elif (
                resolver.schema.dialect == "bigquery"
                and len(column.parts) == 1
                and column_name in scope.selected_sources
            ):
                # BigQuery allows tables to be referenced as columns, treating them as structs
                scope.replace(column, exp.TableColumn(this=column.this))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars_bigquery(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct, e.g. t.c.f1.f2.*, find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_struct_stars_risingwave(expression: exp.Dot) -> t.List[exp.Alias]:
    """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column"""

    # it is not a (<sub_exp>).* pattern, which means we can't expand
    if not isinstance(expression.this, exp.Paren):
        return []

    # find column definition to get data-type
    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    parent = dot_column.parent
    starting_struct = dot_column.type

    # walk up AST and down into struct definition in sync
    while parent is not None:
        if isinstance(parent, exp.Paren):
            parent = parent.parent
            continue

        # if parent is not a dot, then something is wrong
        if not isinstance(parent, exp.Dot):
            return []

        # if the rhs of the dot is star we are done
        rhs = parent.right
        if isinstance(rhs, exp.Star):
            break

        # if it is not an identifier, then something is wrong
        if not isinstance(rhs, exp.Identifier):
            return []

        # Check if current rhs identifier is in struct
        matched = False
        for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
            if struct_field_def.name == rhs.name:
                matched = True
                starting_struct = struct_field_def.kind  # update struct
                break

        if not matched:
            return []

        parent = parent.parent

    # build new aliases to expand star
    new_selections = []

    # fetch the outermost parentheses for new aliases
    outer_paren = expression.this

    for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
        new_identifier = struct_field_def.this.copy()
        new_dot = exp.Dot.build([outer_paren.copy(), new_identifier])
        new_alias = alias(new_dot, new_identifier, copy=False)
        new_selections.append(new_alias)

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )

        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    is_risingwave = dialect == "risingwave"

    if (is_bigquery or is_risingwave) and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars_bigquery(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue
            elif is_risingwave:
                struct_fields = _expand_struct_stars_risingwave(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None or isinstance(selection, exp.QueryTransform):
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
        if isinstance(expression, exp.Select):
            return expression.named_selects
        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
            return self.get_source_columns_from_set_op(expression.this)
        if not isinstance(expression, exp.SetOperation):
            raise OptimizeError(f"Unknown set operation: {expression}")

        set_op = expression

        # BigQuery specific set operation modifiers, e.g. INNER UNION ALL BY NAME
        on_column_list = set_op.args.get("on")

        if on_column_list:
            # The resulting columns are the columns in the ON clause:
            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
            columns = [col.name for col in on_column_list]
        elif set_op.side or set_op.kind:
            side = set_op.side
            kind = set_op.kind

            # Visit the children UNIONs (if any) in a post-order traversal
            left = self.get_source_columns_from_set_op(set_op.left)
            right = self.get_source_columns_from_set_op(set_op.right)

            # We use dict.fromkeys to deduplicate keys and maintain insertion order
            if side == "LEFT":
                columns = left
            elif side == "FULL":
                columns = list(dict.fromkeys(left + right))
            elif kind == "INNER":
                columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
        else:
            columns = set_op.named_selects

        return columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                columns = self.get_source_columns_from_set_op(source.expression)

            else:
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns(
    expression: sqlglot.expressions.Expression,
    schema: Union[Dict, sqlglot.schema.Schema],
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None,
) -> sqlglot.expressions.Expression:
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
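Star expansion (on by default via expand_stars) can be illustrated with the same schema; this is a sketch, with the expected output inferred from the _expand_stars logic in the source:

    >>> qualify_columns(sqlglot.parse_one("SELECT * FROM tbl"), schema).sql()
    'SELECT tbl.col AS col FROM tbl'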
def validate_qualify_columns(expression: ~E) -> ~E:
Raise an OptimizeError if any columns aren't qualified.
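A minimal sketch of the failure mode, assuming two hypothetical tables that share a column name; qualify_columns leaves the reference unqualified because neither source is unambiguous, and validation then raises:

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

    schema = {"x": {"a": "INT"}, "y": {"a": "INT"}}  # hypothetical: both tables define "a"
    expression = qualify_columns(sqlglot.parse_one("SELECT a FROM x, y"), schema)
    validate_qualify_columns(expression)  # raises OptimizeError: Ambiguous columns: [...]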
def qualify_outputs(scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
Ensure all output columns are aliased
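A hedged example of the aliasing behavior; the _col_{i} fallback shown in the comment follows from the source above:

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_outputs

    expression = sqlglot.parse_one("SELECT x.a, a + b FROM x")
    qualify_outputs(expression)  # mutates the expression in place
    expression.sql()  # 'SELECT x.a AS a, a + b AS _col_1 FROM x'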
def quote_identifiers(expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
Makes sure all identifiers that need to be quoted are quoted.
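For instance (a sketch; with the default dialect, identifiers are quoted with double quotes, and identify=True quotes every identifier):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import quote_identifiers

    quote_identifiers(sqlglot.parse_one("SELECT a FROM tbl")).sql()  # 'SELECT "a" FROM "tbl"'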
def pushdown_cte_alias_columns(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class Resolver:
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
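A brief sketch of how a Resolver is wired up outside the optimizer (the schema and query here are hypothetical; qualify_columns constructs one per scope the same way):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import Resolver
    from sqlglot.optimizer.scope import build_scope
    from sqlglot.schema import ensure_schema

    schema = ensure_schema({"tbl": {"col": "INT"}})
    scope = build_scope(sqlglot.parse_one("SELECT col FROM tbl"))
    resolver = Resolver(scope, schema)

    resolver.all_columns                # {'col'}
    resolver.get_table("col")           # Identifier(this=tbl, ...): the only unambiguous source
    resolver.get_source_columns("tbl")  # ['col']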
Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
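When the schema is empty, the infer_schema path falls back to the single source that could provide the column. A hedged continuation of the sketch from the class docstring:

    scope = build_scope(sqlglot.parse_one("SELECT col FROM tbl AS t"))
    resolver = Resolver(scope, ensure_schema({}))  # empty schema, infer_schema defaults to True
    resolver.get_table("col")  # Identifier(this=t, ...): the only source without a schema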
all_columns: Set[str]
All available columns of all sources in this scope
def get_source_columns_from_set_op(self, expression: sqlglot.expressions.Expression) -> List[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
Resolve the source columns for a given source name.