sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    # If the caller didn't say, only infer a schema when none was provided.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects (see FORCE_EARLY_ALIAS_REF_EXPANSION) need alias refs expanded
        # before column qualification; otherwise it's done after, with schema info.
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    # Yields the UNPIVOT's name columns (from its IN fields) followed by the
    # value columns referenced in its expressions.
    name_columns = [
        field.this
        for field in unpivot.fields
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    ]
    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))

    return itertools.chain(name_columns, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        # Recursive CTEs need their column aliases to resolve the recursive part, so keep them.
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite JOIN ... USING (...) into equivalent JOIN ... ON conditions, and replace
    unqualified references to USING columns with COALESCE over the joined sources.

    Returns:
        Mapping of each USING column name to an ordered "set" (dict with None values)
        of the source names it was joined across.
    """
    # First-seen source for every column name, in source order.
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Non-join sources (e.g. the FROM clause), in selection order.
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only raise when we have enough schema info to be sure the join is impossible.
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # For chained USING joins, the left-hand side must coalesce the column
                # across all previously joined sources that expose it.
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
        => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select) or dialect == "oracle":
        return

    # alias name -> (aliased expression, 1-based projection index)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}
    is_bigquery = dialect == "bigquery"

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                # Don't nest aggregates inside aggregates, unless the reference is windowed.
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x" reference with the projection MAX(x.b)
                # i.e HAVING MAX(MAX(x.b).b), resulting in the error: "Aggregations of aggregations are not allowed"
                if is_having and is_bigquery:
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )
            elif is_bigquery and (is_group_by or is_having):
                column_table = table.name if table else column.table
                if column_table in projections:
                    # BigQuery's GROUP BY and HAVING clauses get confused if the column name
                    # matches a source name and a projection. For instance:
                    #   SELECT id, ARRAY_AGG(col) AS custom_fields FROM custom_fields GROUP BY id HAVING id >= 1
                    # We should not qualify "id" with "custom_fields" in either clause, since the aggregation shadows the actual table
                    # and we'd get the error: "Column custom_fields contains an aggregation function, which is not allowed in GROUP BY clause"
                    column.replace(exp.to_identifier(column.name))
                    return

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if (isinstance(alias_expr, exp.Literal) or alias_expr.is_number) and (
                    literal_index or resolve_table
                ):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column, dialect)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    # Determine whether this scope sits in the right (recursive) branch of a recursive CTE.
    parent_scope = scope
    on_right_sub_tree = False
    while parent_scope and not parent_scope.is_cte:
        if parent_scope.is_union:
            on_right_sub_tree = parent_scope.parent.expression.right is parent_scope.expression
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    # and we are in the recursive part (right sub tree) of the CTE
    if parent_scope and on_right_sub_tree:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    # Replace positional references (GROUP BY 1, 2, ...) with the projections they point to.
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    # Expand positional references in ORDER BY and DISTINCT ON; both follow the same rules.
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            # ORDER BY wraps each expression in an exp.Ordered node; unwrap it.
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            # Qualify columns inside aggregates before replacing, since the expanded
            # node may be an alias reference that hides them.
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """
    Return a copy of `expressions` where integer literals are replaced by the
    projection they reference (or its alias, when `alias=True`).
    """
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Keep the ordinal if expanding it would be invalid or ambiguous
                # (constants, numbers, EXPLODE/UNNEST projections, BQ ambiguity).
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.is_number
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    # Fetch the projection for a 1-based positional reference; the projection is
    # expected to be aliased at this point (qualify_outputs has run).
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # A "table" that isn't a known source (here or in the parent scope for
        # correlated subqueries) is actually a struct column.
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)
            elif (
                resolver.schema.dialect == "bigquery"
                and len(column.parts) == 1
                and column_name in scope.selected_sources
            ):
                # BigQuery allows tables to be referenced as columns, treating them as structs
                scope.replace(column, exp.TableColumn(this=column.this))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars_bigquery(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_struct_stars_risingwave(expression: exp.Dot) -> t.List[exp.Alias]:
    """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column"""

    # it is not (<sub_exp>).* pattern, which means we can't expand
    if not isinstance(expression.this, exp.Paren):
        return []

    # find column definition to get data-type
    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    parent = dot_column.parent
    starting_struct = dot_column.type

    # walk up AST and down into struct definition in sync
    while parent is not None:
        if isinstance(parent, exp.Paren):
            parent = parent.parent
            continue

        # if parent is not a dot, then something is wrong
        if not isinstance(parent, exp.Dot):
            return []

        # if the rhs of the dot is star we are done
        rhs = parent.right
        if isinstance(rhs, exp.Star):
            break

        # if it is not identifier, then something is wrong
        if not isinstance(rhs, exp.Identifier):
            return []

        # Check if current rhs identifier is in struct
        matched = False
        for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
            if struct_field_def.name == rhs.name:
                matched = True
                starting_struct = struct_field_def.kind  # update struct
                break

        if not matched:
            return []

        parent = parent.parent

    # build new aliases to expand star
    new_selections = []

    # fetch the outermost parentheses for new aliaes
    outer_paren = expression.this

    for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
        new_identifier = struct_field_def.this.copy()
        new_dot = exp.Dot.build([outer_paren.copy(), new_identifier])
        new_alias = alias(new_dot, new_identifier, copy=False)
        new_selections.append(new_alias)

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            for field in
pivot.fields: 709 if isinstance(field, exp.In): 710 pivot_exclude_columns.update( 711 c.output_name for e in field.expressions for c in e.find_all(exp.Column) 712 ) 713 714 else: 715 pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column)) 716 717 pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])] 718 if not pivot_output_columns: 719 pivot_output_columns = [c.alias_or_name for c in pivot.expressions] 720 721 is_bigquery = dialect == "bigquery" 722 is_risingwave = dialect == "risingwave" 723 724 if (is_bigquery or is_risingwave) and any(isinstance(col, exp.Dot) for col in scope.stars): 725 # Found struct expansion, annotate scope ahead of time 726 annotator.annotate_scope(scope) 727 728 for expression in scope.expression.selects: 729 tables = [] 730 if isinstance(expression, exp.Star): 731 tables.extend(scope.selected_sources) 732 _add_except_columns(expression, tables, except_columns) 733 _add_replace_columns(expression, tables, replace_columns) 734 _add_rename_columns(expression, tables, rename_columns) 735 elif expression.is_star: 736 if not isinstance(expression, exp.Dot): 737 tables.append(expression.table) 738 _add_except_columns(expression.this, tables, except_columns) 739 _add_replace_columns(expression.this, tables, replace_columns) 740 _add_rename_columns(expression.this, tables, rename_columns) 741 elif is_bigquery: 742 struct_fields = _expand_struct_stars_bigquery(expression) 743 if struct_fields: 744 new_selections.extend(struct_fields) 745 continue 746 elif is_risingwave: 747 struct_fields = _expand_struct_stars_risingwave(expression) 748 if struct_fields: 749 new_selections.extend(struct_fields) 750 continue 751 752 if not tables: 753 new_selections.append(expression) 754 continue 755 756 for table in tables: 757 if table not in scope.sources: 758 raise OptimizeError(f"Unknown table: {table}") 759 760 columns = resolver.get_source_columns(table, only_visible=True) 761 columns = columns or 
scope.outer_columns 762 763 if pseudocolumns: 764 columns = [name for name in columns if name.upper() not in pseudocolumns] 765 766 if not columns or "*" in columns: 767 return 768 769 table_id = id(table) 770 columns_to_exclude = except_columns.get(table_id) or set() 771 renamed_columns = rename_columns.get(table_id, {}) 772 replaced_columns = replace_columns.get(table_id, {}) 773 774 if pivot: 775 if pivot_output_columns and pivot_exclude_columns: 776 pivot_columns = [c for c in columns if c not in pivot_exclude_columns] 777 pivot_columns.extend(pivot_output_columns) 778 else: 779 pivot_columns = pivot.alias_column_names 780 781 if pivot_columns: 782 new_selections.extend( 783 alias(exp.column(name, table=pivot.alias), name, copy=False) 784 for name in pivot_columns 785 if name not in columns_to_exclude 786 ) 787 continue 788 789 for name in columns: 790 if name in columns_to_exclude or name in coalesced_columns: 791 continue 792 if name in using_column_tables and table in using_column_tables[name]: 793 coalesced_columns.add(name) 794 tables = using_column_tables[name] 795 coalesce_args = [exp.column(name, table=table) for table in tables] 796 797 new_selections.append( 798 alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False) 799 ) 800 else: 801 alias_ = renamed_columns.get(name, name) 802 selection_expr = replaced_columns.get(name) or exp.column(name, table=table) 803 new_selections.append( 804 alias(selection_expr, alias_, copy=False) 805 if alias_ != name 806 else selection_expr 807 ) 808 809 # Ensures we don't overwrite the initial selections with an empty list 810 if new_selections and isinstance(scope.expression, exp.Select): 811 scope.expression.set("expressions", new_selections) 812 813 814def _add_except_columns( 815 expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]] 816) -> None: 817 except_ = expression.args.get("except") 818 819 if not except_: 820 return 821 822 columns = {e.name for e in except_} 823 824 for 
table in tables: 825 except_columns[id(table)] = columns 826 827 828def _add_rename_columns( 829 expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]] 830) -> None: 831 rename = expression.args.get("rename") 832 833 if not rename: 834 return 835 836 columns = {e.this.name: e.alias for e in rename} 837 838 for table in tables: 839 rename_columns[id(table)] = columns 840 841 842def _add_replace_columns( 843 expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] 844) -> None: 845 replace = expression.args.get("replace") 846 847 if not replace: 848 return 849 850 columns = {e.alias: e for e in replace} 851 852 for table in tables: 853 replace_columns[id(table)] = columns 854 855 856def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 857 """Ensure all output columns are aliased""" 858 if isinstance(scope_or_expression, exp.Expression): 859 scope = build_scope(scope_or_expression) 860 if not isinstance(scope, Scope): 861 return 862 else: 863 scope = scope_or_expression 864 865 new_selections = [] 866 for i, (selection, aliased_column) in enumerate( 867 itertools.zip_longest(scope.expression.selects, scope.outer_columns) 868 ): 869 if selection is None or isinstance(selection, exp.QueryTransform): 870 break 871 872 if isinstance(selection, exp.Subquery): 873 if not selection.output_name: 874 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 875 elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star: 876 selection = alias( 877 selection, 878 alias=selection.output_name or f"_col_{i}", 879 copy=False, 880 ) 881 if aliased_column: 882 selection.set("alias", exp.to_identifier(aliased_column)) 883 884 new_selections.append(selection) 885 886 if new_selections and isinstance(scope.expression, exp.Select): 887 scope.expression.set("expressions", new_selections) 888 889 890def quote_identifiers(expression: E, dialect: DialectType = None, 
identify: bool = True) -> E: 891 """Makes sure all identifiers that need to be quoted are quoted.""" 892 return expression.transform( 893 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 894 ) # type: ignore 895 896 897def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: 898 """ 899 Pushes down the CTE alias columns into the projection, 900 901 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 902 903 Example: 904 >>> import sqlglot 905 >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") 906 >>> pushdown_cte_alias_columns(expression).sql() 907 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' 908 909 Args: 910 expression: Expression to pushdown. 911 912 Returns: 913 The expression with the CTE aliases pushed down into the projection. 914 """ 915 for cte in expression.find_all(exp.CTE): 916 if cte.alias_column_names: 917 new_expressions = [] 918 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 919 if isinstance(projection, exp.Alias): 920 projection.set("alias", _alias) 921 else: 922 projection = alias(projection, alias=_alias) 923 new_expressions.append(projection) 924 cte.this.set("expressions", new_expressions) 925 926 return expression 927 928 929class Resolver: 930 """ 931 Helper for resolving columns. 932 933 This is a class so we can lazily load some things and easily share them across functions. 
934 """ 935 936 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 937 self.scope = scope 938 self.schema = schema 939 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 940 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 941 self._all_columns: t.Optional[t.Set[str]] = None 942 self._infer_schema = infer_schema 943 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {} 944 945 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 946 """ 947 Get the table for a column name. 948 949 Args: 950 column_name: The column name to find the table for. 951 Returns: 952 The table name if it can be found/inferred. 953 """ 954 if self._unambiguous_columns is None: 955 self._unambiguous_columns = self._get_unambiguous_columns( 956 self._get_all_source_columns() 957 ) 958 959 table_name = self._unambiguous_columns.get(column_name) 960 961 if not table_name and self._infer_schema: 962 sources_without_schema = tuple( 963 source 964 for source, columns in self._get_all_source_columns().items() 965 if not columns or "*" in columns 966 ) 967 if len(sources_without_schema) == 1: 968 table_name = sources_without_schema[0] 969 970 if table_name not in self.scope.selected_sources: 971 return exp.to_identifier(table_name) 972 973 node, _ = self.scope.selected_sources.get(table_name) 974 975 if isinstance(node, exp.Query): 976 while node and node.alias != table_name: 977 node = node.parent 978 979 node_alias = node.args.get("alias") 980 if node_alias: 981 return exp.to_identifier(node_alias.this) 982 983 return exp.to_identifier(table_name) 984 985 @property 986 def all_columns(self) -> t.Set[str]: 987 """All available columns of all sources in this scope""" 988 if self._all_columns is None: 989 self._all_columns = { 990 column for columns in self._get_all_source_columns().values() for column in columns 991 } 992 return self._all_columns 993 994 def get_source_columns_from_set_op(self, 
expression: exp.Expression) -> t.List[str]: 995 if isinstance(expression, exp.Select): 996 return expression.named_selects 997 if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation): 998 # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting 999 return self.get_source_columns_from_set_op(expression.this) 1000 if not isinstance(expression, exp.SetOperation): 1001 raise OptimizeError(f"Unknown set operation: {expression}") 1002 1003 set_op = expression 1004 1005 # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME 1006 on_column_list = set_op.args.get("on") 1007 1008 if on_column_list: 1009 # The resulting columns are the columns in the ON clause: 1010 # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...) 1011 columns = [col.name for col in on_column_list] 1012 elif set_op.side or set_op.kind: 1013 side = set_op.side 1014 kind = set_op.kind 1015 1016 # Visit the children UNIONs (if any) in a post-order traversal 1017 left = self.get_source_columns_from_set_op(set_op.left) 1018 right = self.get_source_columns_from_set_op(set_op.right) 1019 1020 # We use dict.fromkeys to deduplicate keys and maintain insertion order 1021 if side == "LEFT": 1022 columns = left 1023 elif side == "FULL": 1024 columns = list(dict.fromkeys(left + right)) 1025 elif kind == "INNER": 1026 columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys()) 1027 else: 1028 columns = set_op.named_selects 1029 1030 return columns 1031 1032 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 1033 """Resolve the source columns for a given source `name`.""" 1034 cache_key = (name, only_visible) 1035 if cache_key not in self._get_source_columns_cache: 1036 if name not in self.scope.sources: 1037 raise OptimizeError(f"Unknown table: {name}") 1038 1039 source = self.scope.sources[name] 1040 1041 if isinstance(source, exp.Table): 1042 columns = 
self.schema.column_names(source, only_visible) 1043 elif isinstance(source, Scope) and isinstance( 1044 source.expression, (exp.Values, exp.Unnest) 1045 ): 1046 columns = source.expression.named_selects 1047 1048 # in bigquery, unnest structs are automatically scoped as tables, so you can 1049 # directly select a struct field in a query. 1050 # this handles the case where the unnest is statically defined. 1051 if self.schema.dialect == "bigquery": 1052 if source.expression.is_type(exp.DataType.Type.STRUCT): 1053 for k in source.expression.type.expressions: # type: ignore 1054 columns.append(k.name) 1055 elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation): 1056 columns = self.get_source_columns_from_set_op(source.expression) 1057 1058 else: 1059 select = seq_get(source.expression.selects, 0) 1060 1061 if isinstance(select, exp.QueryTransform): 1062 # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html 1063 schema = select.args.get("schema") 1064 columns = [c.name for c in schema.expressions] if schema else ["key", "value"] 1065 else: 1066 columns = source.expression.named_selects 1067 1068 node, _ = self.scope.selected_sources.get(name) or (None, None) 1069 if isinstance(node, Scope): 1070 column_aliases = node.expression.alias_column_names 1071 elif isinstance(node, exp.Expression): 1072 column_aliases = node.alias_column_names 1073 else: 1074 column_aliases = [] 1075 1076 if column_aliases: 1077 # If the source's columns are aliased, their aliases shadow the corresponding column names. 1078 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 
1079 columns = [ 1080 alias or name 1081 for (name, alias) in itertools.zip_longest(columns, column_aliases) 1082 ] 1083 1084 self._get_source_columns_cache[cache_key] = columns 1085 1086 return self._get_source_columns_cache[cache_key] 1087 1088 def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]: 1089 if self._source_columns is None: 1090 self._source_columns = { 1091 source_name: self.get_source_columns(source_name) 1092 for source_name, source in itertools.chain( 1093 self.scope.selected_sources.items(), self.scope.lateral_sources.items() 1094 ) 1095 } 1096 return self._source_columns 1097 1098 def _get_unambiguous_columns( 1099 self, source_columns: t.Dict[str, t.Sequence[str]] 1100 ) -> t.Mapping[str, str]: 1101 """ 1102 Find all the unambiguous columns in sources. 1103 1104 Args: 1105 source_columns: Mapping of names to source columns. 1106 1107 Returns: 1108 Mapping of column name to source name. 1109 """ 1110 if not source_columns: 1111 return {} 1112 1113 source_columns_pairs = list(source_columns.items()) 1114 1115 first_table, first_columns = source_columns_pairs[0] 1116 1117 if len(source_columns_pairs) == 1: 1118 # Performance optimization - avoid copying first_columns if there is only one table. 1119 return SingleValuedMapping(first_columns, first_table) 1120 1121 unambiguous_columns = {col: first_table for col in first_columns} 1122 all_columns = set(unambiguous_columns) 1123 1124 for table, columns in source_columns_pairs[1:]: 1125 unique = set(columns) 1126 ambiguous = all_columns.intersection(unique) 1127 all_columns.update(columns) 1128 1129 for column in ambiguous: 1130 unambiguous_columns.pop(column, None) 1131 for column in unique.difference(ambiguous): 1132 unambiguous_columns[column] = table 1133 1134 return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
20def qualify_columns( 21 expression: exp.Expression, 22 schema: t.Dict | Schema, 23 expand_alias_refs: bool = True, 24 expand_stars: bool = True, 25 infer_schema: t.Optional[bool] = None, 26 allow_partial_qualification: bool = False, 27 dialect: DialectType = None, 28) -> exp.Expression: 29 """ 30 Rewrite sqlglot AST to have fully qualified columns. 31 32 Example: 33 >>> import sqlglot 34 >>> schema = {"tbl": {"col": "INT"}} 35 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 36 >>> qualify_columns(expression, schema).sql() 37 'SELECT tbl.col AS col FROM tbl' 38 39 Args: 40 expression: Expression to qualify. 41 schema: Database schema. 42 expand_alias_refs: Whether to expand references to aliases. 43 expand_stars: Whether to expand star queries. This is a necessary step 44 for most of the optimizer's rules to work; do not set to False unless you 45 know what you're doing! 46 infer_schema: Whether to infer the schema if missing. 47 allow_partial_qualification: Whether to allow partial qualification. 48 49 Returns: 50 The qualified expression. 
51 52 Notes: 53 - Currently only handles a single PIVOT or UNPIVOT operator 54 """ 55 schema = ensure_schema(schema, dialect=dialect) 56 annotator = TypeAnnotator(schema) 57 infer_schema = schema.empty if infer_schema is None else infer_schema 58 dialect = Dialect.get_or_raise(schema.dialect) 59 pseudocolumns = dialect.PSEUDOCOLUMNS 60 bigquery = dialect == "bigquery" 61 62 for scope in traverse_scope(expression): 63 scope_expression = scope.expression 64 is_select = isinstance(scope_expression, exp.Select) 65 66 if is_select and scope_expression.args.get("connect"): 67 # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL 68 # pseudocolumn, which doesn't belong to a table, so we change it into an identifier 69 scope_expression.transform( 70 lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n, 71 copy=False, 72 ) 73 scope.clear_cache() 74 75 resolver = Resolver(scope, schema, infer_schema=infer_schema) 76 _pop_table_column_aliases(scope.ctes) 77 _pop_table_column_aliases(scope.derived_tables) 78 using_column_tables = _expand_using(scope, resolver) 79 80 if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs: 81 _expand_alias_refs( 82 scope, 83 resolver, 84 dialect, 85 expand_only_groupby=bigquery, 86 ) 87 88 _convert_columns_to_dots(scope, resolver) 89 _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification) 90 91 if not schema.empty and expand_alias_refs: 92 _expand_alias_refs(scope, resolver, dialect) 93 94 if is_select: 95 if expand_stars: 96 _expand_stars( 97 scope, 98 resolver, 99 using_column_tables, 100 pseudocolumns, 101 annotator, 102 ) 103 qualify_outputs(scope) 104 105 _expand_group_by(scope, dialect) 106 107 # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse) 108 # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT 109 _expand_order_by_and_distinct_on(scope, resolver) 110 111 if 
bigquery: 112 annotator.annotate_scope(scope) 113 114 return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
117def validate_qualify_columns(expression: E) -> E: 118 """Raise an `OptimizeError` if any columns aren't qualified""" 119 all_unqualified_columns = [] 120 for scope in traverse_scope(expression): 121 if isinstance(scope.expression, exp.Select): 122 unqualified_columns = scope.unqualified_columns 123 124 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 125 column = scope.external_columns[0] 126 for_table = f" for table: '{column.table}'" if column.table else "" 127 raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") 128 129 if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: 130 # New columns produced by the UNPIVOT can't be qualified, but there may be columns 131 # under the UNPIVOT's IN clause that can and should be qualified. We recompute 132 # this list here to ensure those in the former category will be excluded. 133 unpivot_columns = set(_unpivot_columns(scope.pivots[0])) 134 unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] 135 136 all_unqualified_columns.extend(unqualified_columns) 137 138 if all_unqualified_columns: 139 raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") 140 141 return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
857def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 858 """Ensure all output columns are aliased""" 859 if isinstance(scope_or_expression, exp.Expression): 860 scope = build_scope(scope_or_expression) 861 if not isinstance(scope, Scope): 862 return 863 else: 864 scope = scope_or_expression 865 866 new_selections = [] 867 for i, (selection, aliased_column) in enumerate( 868 itertools.zip_longest(scope.expression.selects, scope.outer_columns) 869 ): 870 if selection is None or isinstance(selection, exp.QueryTransform): 871 break 872 873 if isinstance(selection, exp.Subquery): 874 if not selection.output_name: 875 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 876 elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star: 877 selection = alias( 878 selection, 879 alias=selection.output_name or f"_col_{i}", 880 copy=False, 881 ) 882 if aliased_column: 883 selection.set("alias", exp.to_identifier(aliased_column)) 884 885 new_selections.append(selection) 886 887 if new_selections and isinstance(scope.expression, exp.Select): 888 scope.expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
891def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 892 """Makes sure all identifiers that need to be quoted are quoted.""" 893 return expression.transform( 894 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 895 ) # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
898def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: 899 """ 900 Pushes down the CTE alias columns into the projection, 901 902 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 903 904 Example: 905 >>> import sqlglot 906 >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") 907 >>> pushdown_cte_alias_columns(expression).sql() 908 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' 909 910 Args: 911 expression: Expression to pushdown. 912 913 Returns: 914 The expression with the CTE aliases pushed down into the projection. 915 """ 916 for cte in expression.find_all(exp.CTE): 917 if cte.alias_column_names: 918 new_expressions = [] 919 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 920 if isinstance(projection, exp.Alias): 921 projection.set("alias", _alias) 922 else: 923 projection = alias(projection, alias=_alias) 924 new_expressions.append(projection) 925 cte.this.set("expressions", new_expressions) 926 927 return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") >>> pushdown_cte_alias_columns(expression).sql() 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches; computed on first use and shared across calls.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Keyed by (source name, only_visible) so both variants can be cached.
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If the column isn't unambiguous, fall back to the single source whose
            # schema is unknown (no columns, or a star), if exactly one such source exists.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the node that actually carries the alias `table_name`.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]:
        """Compute the output column names of a set operation (UNION & friends).

        Args:
            expression: A SELECT, a subquery wrapping a set operation, or a set operation.
        Returns:
            The list of resulting column names.
        Raises:
            OptimizeError: If `expression` is none of the supported node types.
        """
        if isinstance(expression, exp.Select):
            return expression.named_selects
        if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation):
            # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting
            return self.get_source_columns_from_set_op(expression.this)
        if not isinstance(expression, exp.SetOperation):
            raise OptimizeError(f"Unknown set operation: {expression}")

        set_op = expression

        # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
        on_column_list = set_op.args.get("on")

        if on_column_list:
            # The resulting columns are the columns in the ON clause:
            # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
            columns = [col.name for col in on_column_list]
        elif set_op.side or set_op.kind:
            side = set_op.side
            kind = set_op.kind

            # Visit the children UNIONs (if any) in a post-order traversal
            left = self.get_source_columns_from_set_op(set_op.left)
            right = self.get_source_columns_from_set_op(set_op.right)

            # We use dict.fromkeys to deduplicate keys and maintain insertion order
            if side == "LEFT":
                columns = left
            elif side == "FULL":
                columns = list(dict.fromkeys(left + right))
            elif kind == "INNER":
                # Intersecting keys views directly would produce a *set*, whose iteration
                # order is hash-dependent; keep the left operand's insertion order instead.
                right_columns = set(right)
                columns = [column for column in dict.fromkeys(left) if column in right_columns]
        else:
            columns = set_op.named_selects

        return columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                columns = self.get_source_columns_from_set_op(source.expression)
            else:
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Map every selected/lateral source name to its resolved columns (cached).
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column seen in more than one source can no longer be resolved unambiguously.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
937 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 938 self.scope = scope 939 self.schema = schema 940 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 941 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 942 self._all_columns: t.Optional[t.Set[str]] = None 943 self._infer_schema = infer_schema 944 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
946 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 947 """ 948 Get the table for a column name. 949 950 Args: 951 column_name: The column name to find the table for. 952 Returns: 953 The table name if it can be found/inferred. 954 """ 955 if self._unambiguous_columns is None: 956 self._unambiguous_columns = self._get_unambiguous_columns( 957 self._get_all_source_columns() 958 ) 959 960 table_name = self._unambiguous_columns.get(column_name) 961 962 if not table_name and self._infer_schema: 963 sources_without_schema = tuple( 964 source 965 for source, columns in self._get_all_source_columns().items() 966 if not columns or "*" in columns 967 ) 968 if len(sources_without_schema) == 1: 969 table_name = sources_without_schema[0] 970 971 if table_name not in self.scope.selected_sources: 972 return exp.to_identifier(table_name) 973 974 node, _ = self.scope.selected_sources.get(table_name) 975 976 if isinstance(node, exp.Query): 977 while node and node.alias != table_name: 978 node = node.parent 979 980 node_alias = node.args.get("alias") 981 if node_alias: 982 return exp.to_identifier(node_alias.this) 983 984 return exp.to_identifier(table_name)
Get the table for a column name.
Args:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
986 @property 987 def all_columns(self) -> t.Set[str]: 988 """All available columns of all sources in this scope""" 989 if self._all_columns is None: 990 self._all_columns = { 991 column for columns in self._get_all_source_columns().values() for column in columns 992 } 993 return self._all_columns
All available columns of all sources in this scope
995 def get_source_columns_from_set_op(self, expression: exp.Expression) -> t.List[str]: 996 if isinstance(expression, exp.Select): 997 return expression.named_selects 998 if isinstance(expression, exp.Subquery) and isinstance(expression.this, exp.SetOperation): 999 # Different types of SET modifiers can be chained together if they're explicitly grouped by nesting 1000 return self.get_source_columns_from_set_op(expression.this) 1001 if not isinstance(expression, exp.SetOperation): 1002 raise OptimizeError(f"Unknown set operation: {expression}") 1003 1004 set_op = expression 1005 1006 # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME 1007 on_column_list = set_op.args.get("on") 1008 1009 if on_column_list: 1010 # The resulting columns are the columns in the ON clause: 1011 # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...) 1012 columns = [col.name for col in on_column_list] 1013 elif set_op.side or set_op.kind: 1014 side = set_op.side 1015 kind = set_op.kind 1016 1017 # Visit the children UNIONs (if any) in a post-order traversal 1018 left = self.get_source_columns_from_set_op(set_op.left) 1019 right = self.get_source_columns_from_set_op(set_op.right) 1020 1021 # We use dict.fromkeys to deduplicate keys and maintain insertion order 1022 if side == "LEFT": 1023 columns = left 1024 elif side == "FULL": 1025 columns = list(dict.fromkeys(left + right)) 1026 elif kind == "INNER": 1027 columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys()) 1028 else: 1029 columns = set_op.named_selects 1030 1031 return columns
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
1033 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 1034 """Resolve the source columns for a given source `name`.""" 1035 cache_key = (name, only_visible) 1036 if cache_key not in self._get_source_columns_cache: 1037 if name not in self.scope.sources: 1038 raise OptimizeError(f"Unknown table: {name}") 1039 1040 source = self.scope.sources[name] 1041 1042 if isinstance(source, exp.Table): 1043 columns = self.schema.column_names(source, only_visible) 1044 elif isinstance(source, Scope) and isinstance( 1045 source.expression, (exp.Values, exp.Unnest) 1046 ): 1047 columns = source.expression.named_selects 1048 1049 # in bigquery, unnest structs are automatically scoped as tables, so you can 1050 # directly select a struct field in a query. 1051 # this handles the case where the unnest is statically defined. 1052 if self.schema.dialect == "bigquery": 1053 if source.expression.is_type(exp.DataType.Type.STRUCT): 1054 for k in source.expression.type.expressions: # type: ignore 1055 columns.append(k.name) 1056 elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation): 1057 columns = self.get_source_columns_from_set_op(source.expression) 1058 1059 else: 1060 select = seq_get(source.expression.selects, 0) 1061 1062 if isinstance(select, exp.QueryTransform): 1063 # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html 1064 schema = select.args.get("schema") 1065 columns = [c.name for c in schema.expressions] if schema else ["key", "value"] 1066 else: 1067 columns = source.expression.named_selects 1068 1069 node, _ = self.scope.selected_sources.get(name) or (None, None) 1070 if isinstance(node, Scope): 1071 column_aliases = node.expression.alias_column_names 1072 elif isinstance(node, exp.Expression): 1073 column_aliases = node.alias_column_names 1074 else: 1075 column_aliases = [] 1076 1077 if column_aliases: 1078 # If the source's columns are aliased, their aliases shadow the 
corresponding column names. 1079 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 1080 columns = [ 1081 alias or name 1082 for (name, alias) in itertools.zip_longest(columns, column_aliases) 1083 ] 1084 1085 self._get_source_columns_cache[cache_key] = columns 1086 1087 return self._get_source_columns_cache[cache_key]
Resolve the source columns for a given source `name`.