sqlglot.optimizer.qualify_columns
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import alias, exp 7from sqlglot.dialects.dialect import Dialect, DialectType 8from sqlglot.errors import OptimizeError, highlight_sql 9from sqlglot.helper import seq_get 10from sqlglot.optimizer.annotate_types import TypeAnnotator 11from sqlglot.optimizer.resolver import Resolver 12from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope 13from sqlglot.optimizer.simplify import simplify_parens 14from sqlglot.schema import Schema, ensure_schema 15 16if t.TYPE_CHECKING: 17 from sqlglot._typing import E 18 from collections.abc import Iterator, Iterable 19 20 21def qualify_columns( 22 expression: exp.Expr, 23 schema: dict[str, object] | Schema, 24 expand_alias_refs: bool = True, 25 expand_stars: bool = True, 26 infer_schema: bool | None = None, 27 allow_partial_qualification: bool = False, 28 dialect: DialectType = None, 29) -> exp.Expr: 30 """ 31 Rewrite sqlglot AST to have fully qualified columns. 32 33 Example: 34 >>> import sqlglot 35 >>> schema = {"tbl": {"col": "INT"}} 36 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 37 >>> qualify_columns(expression, schema).sql() 38 'SELECT tbl.col AS col FROM tbl' 39 40 Args: 41 expression: Expr to qualify. 42 schema: Database schema. 43 expand_alias_refs: Whether to expand references to aliases. 44 expand_stars: Whether to expand star queries. This is a necessary step 45 for most of the optimizer's rules to work; do not set to False unless you 46 know what you're doing! 47 infer_schema: Whether to infer the schema if missing. 48 allow_partial_qualification: Whether to allow partial qualification. 49 50 Returns: 51 The qualified expression. 
52 53 Notes: 54 - Currently only handles a single PIVOT or UNPIVOT operator 55 """ 56 schema = ensure_schema(schema, dialect=dialect) 57 annotator = TypeAnnotator(schema) 58 infer_schema = schema.empty if infer_schema is None else infer_schema 59 dialect = schema.dialect or Dialect() 60 pseudocolumns = dialect.PSEUDOCOLUMNS 61 62 for scope in traverse_scope(expression): 63 if dialect.PREFER_CTE_ALIAS_COLUMN: 64 pushdown_cte_alias_columns(scope) 65 66 scope_expression = scope.expression 67 is_select = isinstance(scope_expression, exp.Select) 68 69 _separate_pseudocolumns(scope, pseudocolumns) 70 71 resolver = Resolver(scope, schema, infer_schema=infer_schema) 72 _pop_table_column_aliases(scope.ctes) 73 _pop_table_column_aliases(scope.derived_tables) 74 using_column_tables = _expand_using(scope, resolver) 75 76 if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs: 77 _expand_alias_refs( 78 scope, 79 resolver, 80 dialect, 81 expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF, 82 ) 83 84 _convert_columns_to_dots(scope, resolver) 85 _qualify_columns( 86 scope, 87 resolver, 88 allow_partial_qualification=allow_partial_qualification, 89 ) 90 91 if not schema.empty and expand_alias_refs: 92 _expand_alias_refs(scope, resolver, dialect) 93 94 if is_select: 95 if expand_stars: 96 _expand_stars( 97 scope, 98 resolver, 99 using_column_tables, 100 pseudocolumns, 101 annotator, 102 ) 103 qualify_outputs(scope) 104 105 _expand_group_by(scope, dialect) 106 107 # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse) 108 # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT 109 _expand_order_by_and_distinct_on(scope, resolver) 110 111 if dialect.ANNOTATE_ALL_SCOPES: 112 annotator.annotate_scope(scope) 113 114 return expression 115 116 117def validate_qualify_columns(expression: E, sql: str | None = None) -> E: 118 """Raise an `OptimizeError` if any columns aren't qualified""" 119 
all_unqualified_columns = [] 120 for scope in traverse_scope(expression): 121 if isinstance(scope.expression, exp.Select): 122 unqualified_columns = scope.unqualified_columns 123 124 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 125 column = scope.external_columns[0] 126 for_table = f" for table: '{column.table}'" if column.table else "" 127 line = column.this.meta.get("line") 128 col = column.this.meta.get("col") 129 start = column.this.meta.get("start") 130 end = column.this.meta.get("end") 131 132 error_msg = f"Column '{column.name}' could not be resolved{for_table}." 133 if line and col: 134 error_msg += f" Line: {line}, Col: {col}" 135 if sql and start is not None and end is not None: 136 formatted_sql = highlight_sql(sql, [(start, end)])[0] 137 error_msg += f"\n {formatted_sql}" 138 139 raise OptimizeError(error_msg) 140 141 if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: 142 # New columns produced by the UNPIVOT can't be qualified, but there may be columns 143 # under the UNPIVOT's IN clause that can and should be qualified. We recompute 144 # this list here to ensure those in the former category will be excluded. 
145 unpivot_columns = set(_unpivot_columns(scope.pivots[0])) 146 unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] 147 148 all_unqualified_columns.extend(unqualified_columns) 149 150 if all_unqualified_columns: 151 first_column = all_unqualified_columns[0] 152 line = first_column.this.meta.get("line") 153 col = first_column.this.meta.get("col") 154 start = first_column.this.meta.get("start") 155 end = first_column.this.meta.get("end") 156 157 error_msg = f"Ambiguous column '{first_column.name}'" 158 if line and col: 159 error_msg += f" (Line: {line}, Col: {col})" 160 if sql and start is not None and end is not None: 161 formatted_sql = highlight_sql(sql, [(start, end)])[0] 162 error_msg += f"\n {formatted_sql}" 163 164 raise OptimizeError(error_msg) 165 166 return expression 167 168 169def _separate_pseudocolumns(scope: Scope, pseudocolumns: set[str]) -> None: 170 if not pseudocolumns: 171 return 172 173 has_pseudocolumns = False 174 scope_expression = scope.expression 175 176 for column in scope.columns: 177 name = column.name.upper() 178 if name not in pseudocolumns: 179 continue 180 181 if name != "LEVEL" or ( 182 isinstance(scope_expression, exp.Select) and scope_expression.args.get("connect") 183 ): 184 column.replace(exp.Pseudocolumn(**column.args)) 185 has_pseudocolumns = True 186 187 if has_pseudocolumns: 188 scope.clear_cache() 189 190 191def _unpivot_columns(unpivot: exp.Pivot) -> Iterator[exp.Column]: 192 name_columns = [ 193 field.this 194 for field in unpivot.fields 195 if isinstance(field, exp.In) and isinstance(field.this, exp.Column) 196 ] 197 value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column)) 198 199 return itertools.chain(name_columns, value_columns) 200 201 202def _pop_table_column_aliases(derived_tables: Iterable[exp.Expr]) -> None: 203 """ 204 Remove table column aliases. 205 206 For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) 
AS foo(col1, col2) 207 """ 208 for derived_table in derived_tables: 209 if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive: 210 continue 211 table_alias = derived_table.args.get("alias") 212 if table_alias: 213 table_alias.set("columns", None) 214 215 216def _expand_using(scope: Scope, resolver: Resolver) -> dict[str, t.Any]: 217 columns = {} 218 219 def _update_source_columns(source_name: str) -> None: 220 for column_name in resolver.get_source_columns(source_name): 221 if column_name not in columns: 222 columns[column_name] = source_name 223 224 joins = list(scope.find_all(exp.Join)) 225 names = {join.alias_or_name for join in joins} 226 ordered = [key for key in scope.selected_sources if key not in names] 227 228 if names and not ordered: 229 raise OptimizeError(f"Joins {names} missing source table {scope.expression}") 230 231 # Mapping of automatically joined column names to an ordered set of source names (dict). 232 column_tables: dict[str, dict[str, t.Any]] = {} 233 234 for source_name in ordered: 235 _update_source_columns(source_name) 236 237 for i, join in enumerate(joins): 238 source_table = ordered[-1] 239 if source_table: 240 _update_source_columns(source_table) 241 242 join_table = join.alias_or_name 243 ordered.append(join_table) 244 245 using = join.args.get("using") 246 if not using: 247 continue 248 249 join_columns = resolver.get_source_columns(join_table) 250 conditions = [] 251 using_identifier_count = len(using) 252 is_semi_or_anti_join = join.is_semi_or_anti_join 253 254 for identifier in using: 255 identifier = identifier.name 256 table = columns.get(identifier) 257 258 if not table or identifier not in join_columns: 259 if (columns and "*" not in columns) and join_columns: 260 raise OptimizeError(f"Cannot automatically join: {identifier}") 261 262 table = table or source_table 263 264 if i == 0 or using_identifier_count == 1: 265 lhs: exp.Expr = exp.column(identifier, table=table) 266 else: 267 coalesce_columns 
= [ 268 exp.column(identifier, table=t) 269 for t in ordered[:-1] 270 if identifier in resolver.get_source_columns(t) 271 ] 272 if len(coalesce_columns) > 1: 273 lhs = exp.func("coalesce", *coalesce_columns) 274 else: 275 lhs = exp.column(identifier, table=table) 276 277 conditions.append(lhs.eq(exp.column(identifier, table=join_table))) 278 279 # Set all values in the dict to None, because we only care about the key ordering 280 tables = column_tables.setdefault(identifier, {}) 281 282 # Do not update the dict if this was a SEMI/ANTI join in 283 # order to avoid generating COALESCE columns for this join pair 284 if not is_semi_or_anti_join: 285 if table not in tables: 286 tables[table] = None 287 if join_table not in tables: 288 tables[join_table] = None 289 290 join.set("using", None) 291 join.set("on", exp.and_(*conditions, copy=False)) 292 293 if column_tables: 294 for column in scope.columns: 295 if not column.table and column.name in column_tables: 296 tables = column_tables[column.name] 297 coalesce_args = [exp.column(column.name, table=table) for table in tables] 298 replacement: exp.Expr = exp.func("coalesce", *coalesce_args) 299 300 if isinstance(column.parent, exp.Select): 301 # Ensure the USING column keeps its name if it's projected 302 replacement = alias(replacement, alias=column.name, copy=False) 303 elif isinstance(column.parent, exp.Struct): 304 # Ensure the USING column keeps its name if it's an anonymous STRUCT field 305 replacement = exp.PropertyEQ( 306 this=exp.to_identifier(column.name), expression=replacement 307 ) 308 309 scope.replace(column, replacement) 310 311 return column_tables 312 313 314def _expand_alias_refs( 315 scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False 316) -> None: 317 """ 318 Expand references to aliases. 
319 Example: 320 SELECT y.foo AS bar, bar * 2 AS baz FROM y 321 => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y 322 """ 323 expression = scope.expression 324 325 if not isinstance(expression, exp.Select) or dialect.DISABLES_ALIAS_REF_EXPANSION: 326 return 327 328 alias_to_expression: dict[str, tuple[exp.Expr, int]] = {} 329 projections = {s.alias_or_name for s in expression.selects} 330 replaced = False 331 332 def replace_columns( 333 node: exp.Expr | None, resolve_table: bool = False, literal_index: bool = False 334 ) -> None: 335 nonlocal replaced 336 is_group_by = isinstance(node, exp.Group) 337 is_having = isinstance(node, exp.Having) 338 if not node or (expand_only_groupby and not is_group_by): 339 return 340 341 for column in walk_in_scope(node, prune=lambda node: node.is_star): 342 if not isinstance(column, exp.Column): 343 continue 344 345 # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g: 346 # SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded 347 # SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col)) 348 # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns 349 if expand_only_groupby and is_group_by and column.parent is not node: 350 continue 351 352 skip_replace = False 353 table = resolver.get_table(column.name) if resolve_table and not column.table else None 354 alias_expr, i = alias_to_expression.get(column.name, (None, 1)) 355 356 if alias_expr: 357 skip_replace = bool( 358 alias_expr.find(exp.AggFunc) 359 and column.find_ancestor(exp.AggFunc) 360 and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window) 361 ) 362 363 # BigQuery's having clause gets confused if an alias matches a source. 
364 # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1; 365 # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x" reference with the projection MAX(x.b) 366 # i.e HAVING MAX(MAX(x.b).b), resulting in the error: "Aggregations of aggregations are not allowed" 367 if is_having and dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES: 368 skip_replace = skip_replace or any( 369 node.parts[0].name in projections 370 for node in alias_expr.find_all(exp.Column) 371 ) 372 elif dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES and (is_group_by or is_having): 373 column_table = table.name if table else column.table 374 if column_table in projections: 375 # BigQuery's GROUP BY and HAVING clauses get confused if the column name 376 # matches a source name and a projection. For instance: 377 # SELECT id, ARRAY_AGG(col) AS custom_fields FROM custom_fields GROUP BY id HAVING id >= 1 378 # We should not qualify "id" with "custom_fields" in either clause, since the aggregation shadows the actual table 379 # and we'd get the error: "Column custom_fields contains an aggregation function, which is not allowed in GROUP BY clause" 380 column.replace(exp.to_identifier(column.name)) 381 replaced = True 382 return 383 384 if table and (not alias_expr or skip_replace): 385 column.set("table", table) 386 elif not column.table and alias_expr and not skip_replace: 387 if (isinstance(alias_expr, exp.Literal) or alias_expr.is_number) and ( 388 literal_index or resolve_table 389 ): 390 if literal_index: 391 column.replace(exp.Literal.number(i)) 392 replaced = True 393 else: 394 replaced = True 395 column = column.replace(exp.paren(alias_expr)) 396 simplified = simplify_parens(column, dialect) 397 if simplified is not column: 398 column.replace(simplified) 399 400 for i, projection in enumerate(expression.selects): 401 replace_columns(projection) 402 if isinstance(projection, exp.Alias): 403 alias_to_expression[projection.alias] = (projection.this, i + 1) 404 405 
parent_scope: Scope | None = scope 406 on_right_sub_tree = False 407 while parent_scope and not parent_scope.is_cte: 408 if parent_scope := parent_scope.parent: 409 if isinstance(parent_scope.expression, exp.Union): 410 on_right_sub_tree = parent_scope.expression.right is parent_scope.expression 411 412 # We shouldn't expand aliases if they match the recursive CTE's columns 413 # and we are in the recursive part (right sub tree) of the CTE 414 if parent_scope and on_right_sub_tree: 415 if cte := parent_scope.expression.parent: 416 with_ = cte.find_ancestor(exp.With) 417 if with_ and with_.recursive: 418 for recursive_cte_column in cte.args["alias"].columns or cte.this.selects: 419 alias_to_expression.pop(recursive_cte_column.output_name, None) 420 421 replace_columns(expression.args.get("where")) 422 replace_columns(expression.args.get("group"), literal_index=True) 423 replace_columns(expression.args.get("having"), resolve_table=True) 424 replace_columns(expression.args.get("qualify"), resolve_table=True) 425 426 if dialect.SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS: 427 for join in expression.args.get("joins") or []: 428 replace_columns(join) 429 430 if replaced: 431 scope.clear_cache() 432 433 434def _expand_group_by(scope: Scope, dialect: Dialect) -> None: 435 expression = scope.expression 436 group = expression.args.get("group") 437 if not group: 438 return 439 440 group.set("expressions", _expand_positional_references(scope, group.expressions, dialect)) 441 expression.set("group", group) 442 443 444def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None: 445 expression = scope.expression 446 447 if not isinstance(expression, exp.Selectable): 448 return 449 450 # TODO (mypyc): rebind to exp.Expr to avoid Selectable trait vtable dispatch for .args 451 expr: exp.Expr = expression 452 453 for modifier_key in ("order", "distinct"): 454 modifier = expr.args.get(modifier_key) 455 if isinstance(modifier, exp.Distinct): 456 modifier = 
modifier.args.get("on") 457 458 if not isinstance(modifier, exp.Expr): 459 continue 460 461 modifier_expressions = modifier.expressions 462 if modifier_key == "order": 463 modifier_expressions = [ordered.this for ordered in modifier_expressions] 464 465 for original, expanded in zip( 466 modifier_expressions, 467 _expand_positional_references( 468 scope, modifier_expressions, resolver.dialect, alias=True 469 ), 470 ): 471 for agg in original.find_all(exp.AggFunc): 472 for col in agg.find_all(exp.Column): 473 if not col.table: 474 col.set("table", resolver.get_table(col.name)) 475 476 original.replace(expanded) 477 478 if expr.args.get("group"): 479 selects = {s.this: exp.column(s.alias_or_name) for s in expression.selects} 480 481 for node in modifier_expressions: 482 node.replace( 483 exp.to_identifier(_select_by_pos(expression, node).alias) 484 if node.is_int 485 else selects.get(node, node) 486 ) 487 488 489def _expand_positional_references( 490 scope: Scope, expressions: Iterable[exp.Expr], dialect: Dialect, alias: bool = False 491) -> list[exp.Expr]: 492 new_nodes: list[exp.Expr] = [] 493 ambiguous_projections = None 494 495 expression = scope.expression 496 497 if not isinstance(expression, exp.Selectable): 498 return new_nodes 499 500 for node in expressions: 501 if node.is_int and isinstance(node, exp.Literal): 502 select = _select_by_pos(expression, node) 503 504 if alias: 505 new_nodes.append(exp.column(select.args["alias"].copy())) 506 else: 507 # TODO (mypyc): use a separate variable to avoid reusing `select` (Alias) with a different type 508 select_expr: exp.Expr = select.this 509 510 if dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES: 511 if ambiguous_projections is None: 512 # When a projection name is also a source name and it is referenced in the 513 # GROUP BY clause, BQ can't understand what the identifier corresponds to 514 ambiguous_projections = { 515 s.alias_or_name 516 for s in expression.selects 517 if s.alias_or_name in 
scope.selected_sources 518 } 519 520 ambiguous = any( 521 column.parts[0].name in ambiguous_projections 522 for column in select_expr.find_all(exp.Column) 523 ) 524 else: 525 ambiguous = False 526 527 if ( 528 isinstance(select_expr, exp.CONSTANTS) 529 or select_expr.is_number 530 or select_expr.find(exp.Explode, exp.Unnest) 531 or ambiguous 532 ): 533 new_nodes.append(node) 534 else: 535 new_nodes.append(select_expr.copy()) 536 else: 537 new_nodes.append(node) 538 539 return new_nodes 540 541 542def _select_by_pos(expression: exp.Selectable, node: exp.Literal) -> exp.Alias: 543 try: 544 return expression.selects[int(node.this) - 1].assert_is(exp.Alias) 545 except IndexError: 546 raise OptimizeError(f"Unknown output column: {node.name}") 547 548 549def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None: 550 """ 551 Converts `Column` instances that represent STRUCT or JSON field lookup into chained `Dots`. 552 553 These lookups may be parsed as columns (e.g. "col"."field"."field2"), but they need to be 554 normalized to `Dot(Dot(...(<table>.<column>, field1), field2, ...))` to be qualified properly. 
555 """ 556 converted = False 557 for column in itertools.chain(scope.columns, scope.stars): 558 if isinstance(column, exp.Dot): 559 continue 560 561 column_table: str | exp.Identifier | None = column.table 562 dot_parts = column.meta.pop("dot_parts", []) 563 if ( 564 column_table 565 and column_table not in scope.sources 566 and ( 567 not scope.parent 568 or column_table not in scope.parent.sources 569 or not scope.is_correlated_subquery 570 ) 571 ): 572 root, *parts = column.parts 573 574 if isinstance(root, exp.Identifier) and root.name in scope.sources: 575 # The struct is already qualified, but we still need to change the AST 576 column_table = root 577 root, *parts = parts 578 was_qualified = True 579 else: 580 column_table = resolver.get_table(root.name) 581 was_qualified = False 582 583 if column_table: 584 converted = True 585 new_column = exp.column(root, table=column_table) 586 587 if dot_parts: 588 # Remove the actual column parts from the rest of dot parts 589 new_column.meta["dot_parts"] = dot_parts[2 if was_qualified else 1 :] 590 591 column.replace(exp.Dot.build([new_column, *parts])) 592 593 if converted: 594 # We want to re-aggregate the converted columns, otherwise they'd be skipped in 595 # a `for column in scope.columns` iteration, even though they shouldn't be 596 scope.clear_cache() 597 598 599def _qualify_columns( 600 scope: Scope, 601 resolver: Resolver, 602 allow_partial_qualification: bool, 603) -> None: 604 """Disambiguate columns, ensuring each column specifies a source""" 605 for column in scope.columns: 606 column_table = column.table 607 column_name = column.name 608 609 if column_table and column_table in scope.sources: 610 source_columns = resolver.get_source_columns(column_table) 611 if ( 612 not allow_partial_qualification 613 and source_columns 614 and column_name not in source_columns 615 and "*" not in source_columns 616 ): 617 raise OptimizeError(f"Unknown column: {column_name}") 618 619 if not column_table: 620 if 
scope.pivots and not column.find_ancestor(exp.Pivot): 621 # If the column is under the Pivot expression, we need to qualify it 622 # using the name of the pivoted source instead of the pivot's alias 623 column.set("table", exp.to_identifier(scope.pivots[0].alias)) 624 continue 625 626 # column_table can be a '' because bigquery unnest has no table alias 627 table = resolver.get_table(column) 628 629 if ( 630 table 631 and isinstance(source := scope.sources.get(table.name), Scope) 632 and id(column) in source.column_index 633 ): 634 continue 635 636 if table: 637 column.set("table", table) 638 elif ( 639 resolver.dialect.TABLES_REFERENCEABLE_AS_COLUMNS 640 and len(column.parts) == 1 641 and column_name in scope.selected_sources 642 ): 643 # BigQuery and Postgres allow tables to be referenced as columns, treating them as structs/records 644 scope.replace(column, exp.TableColumn(this=column.this)) 645 646 for pivot in scope.pivots: 647 for column in pivot.find_all(exp.Column): 648 if not column.table and column.name in resolver.all_columns: 649 table = resolver.get_table(column.name) 650 if table: 651 column.set("table", table) 652 653 654def _expand_struct_stars_no_parens( 655 expression: exp.Dot, 656) -> list[exp.Alias]: 657 """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column""" 658 659 dot_column = expression.find(exp.Column) 660 if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DType.STRUCT): 661 return [] 662 663 # All nested struct values are ColumnDefs, so normalize the first exp.Column in one 664 dot_column = dot_column.copy() 665 starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type) 666 667 # First part is the table name and last part is the star so they can be dropped 668 dot_parts = expression.parts[1:-1] 669 670 # If we're expanding a nested struct eg. 
t.c.f1.f2.* find the last struct (f2 in this case) 671 for part in dot_parts[1:]: 672 for field in t.cast(exp.DataType, starting_struct.kind).expressions: 673 # Unable to expand star unless all fields are named 674 if not isinstance(field.this, exp.Identifier): 675 return [] 676 677 if field.name == part.name and field.kind.is_type(exp.DType.STRUCT): 678 starting_struct = field 679 break 680 else: 681 # There is no matching field in the struct 682 return [] 683 684 taken_names = set() 685 new_selections = [] 686 687 for field in t.cast(exp.DataType, starting_struct.kind).expressions: 688 name = field.name 689 690 # Ambiguous or anonymous fields can't be expanded 691 if name in taken_names or not isinstance(field.this, exp.Identifier): 692 return [] 693 694 taken_names.add(name) 695 696 this = field.this.copy() 697 root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])] 698 new_column = exp.column( 699 t.cast(exp.Identifier, root), 700 table=dot_column.args.get("table"), 701 fields=t.cast(list[exp.Identifier], parts), 702 ) 703 new_selections.append(alias(new_column, this, copy=False).assert_is(exp.Alias)) 704 705 return new_selections 706 707 708def _expand_struct_stars_with_parens(expression: exp.Dot) -> list[exp.Alias]: 709 """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column""" 710 711 # it is not (<sub_exp>).* pattern, which means we can't expand 712 if not isinstance(expression.this, exp.Paren): 713 return [] 714 715 # find column definition to get data-type 716 dot_column = expression.find(exp.Column) 717 if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DType.STRUCT): 718 return [] 719 720 parent = dot_column.parent 721 starting_struct = dot_column.type 722 723 # walk up AST and down into struct definition in sync 724 while parent is not None: 725 if isinstance(parent, exp.Paren): 726 parent = parent.parent 727 continue 728 729 # if parent is not a dot, then something is wrong 730 if not 
isinstance(parent, exp.Dot): 731 return [] 732 733 # if the rhs of the dot is star we are done 734 rhs = parent.right 735 if isinstance(rhs, exp.Star): 736 break 737 738 # if it is not identifier, then something is wrong 739 if not isinstance(rhs, exp.Identifier): 740 return [] 741 742 # Check if current rhs identifier is in struct 743 matched = False 744 for struct_field_def in t.cast(exp.DataType, starting_struct).expressions: 745 if struct_field_def.name == rhs.name: 746 matched = True 747 starting_struct = struct_field_def.kind # update struct 748 break 749 750 if not matched: 751 return [] 752 753 parent = parent.parent 754 755 # build new aliases to expand star 756 new_selections = [] 757 758 # fetch the outermost parentheses for new aliaes 759 outer_paren = expression.this 760 761 for struct_field_def in t.cast(exp.DataType, starting_struct).expressions: 762 new_identifier = struct_field_def.this.copy() 763 new_dot = exp.Dot.build([outer_paren.copy(), new_identifier]) 764 new_alias = alias(new_dot, new_identifier, copy=False).assert_is(exp.Alias) 765 new_selections.append(new_alias) 766 767 return new_selections 768 769 770def _expand_stars( 771 scope: Scope, 772 resolver: Resolver, 773 using_column_tables: dict[str, t.Any], 774 pseudocolumns: set[str], 775 annotator: TypeAnnotator, 776) -> None: 777 """Expand stars to lists of column selections""" 778 779 new_selections: list[exp.Expr] = [] 780 except_columns: dict[int, set[str]] = {} 781 replace_columns: dict[int, dict[str, exp.Alias]] = {} 782 rename_columns: dict[int, dict[str, str]] = {} 783 784 coalesced_columns = set() 785 dialect = resolver.dialect 786 787 pivot_output_columns = None 788 pivot_exclude_columns: set[str] = set() 789 790 pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) 791 if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names: 792 if pivot.unpivot: 793 pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)] 794 795 for field in pivot.fields: 796 
if isinstance(field, exp.In): 797 pivot_exclude_columns.update( 798 c.output_name for e in field.expressions for c in e.find_all(exp.Column) 799 ) 800 801 else: 802 pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column)) 803 804 pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])] 805 if not pivot_output_columns: 806 pivot_output_columns = [c.alias_or_name for c in pivot.expressions] 807 808 if dialect.SUPPORTS_STRUCT_STAR_EXPANSION and any( 809 isinstance(col, exp.Dot) for col in scope.stars 810 ): 811 # Found struct expansion, annotate scope ahead of time 812 annotator.annotate_scope(scope) 813 814 scope_expression = scope.expression 815 816 if not isinstance(scope_expression, exp.Selectable): 817 return 818 819 for expression in scope_expression.selects: 820 tables: list[str] = [] 821 if isinstance(expression, exp.Star): 822 tables.extend(scope.selected_sources) 823 _add_except_columns(expression, tables, except_columns) 824 _add_replace_columns(expression, tables, replace_columns) 825 _add_rename_columns(expression, tables, rename_columns) 826 elif expression.is_star: 827 if isinstance(expression, exp.Column): 828 tables.append(expression.table) 829 _add_except_columns(expression.this, tables, except_columns) 830 _add_replace_columns(expression.this, tables, replace_columns) 831 _add_rename_columns(expression.this, tables, rename_columns) 832 elif isinstance(expression, exp.Dot): 833 if ( 834 dialect.SUPPORTS_STRUCT_STAR_EXPANSION 835 and not dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS 836 ): 837 struct_fields = _expand_struct_stars_no_parens(expression) 838 if struct_fields: 839 new_selections.extend(struct_fields) 840 continue 841 elif dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS: 842 struct_fields = _expand_struct_stars_with_parens(expression) 843 if struct_fields: 844 new_selections.extend(struct_fields) 845 continue 846 847 if not tables: 848 new_selections.append(expression) 849 continue 850 851 for table 
in tables: 852 if table not in scope.sources: 853 raise OptimizeError(f"Unknown table: {table}") 854 855 columns = resolver.get_source_columns(table, only_visible=True) 856 columns = columns or scope.outer_columns 857 858 if pseudocolumns and dialect.EXCLUDES_PSEUDOCOLUMNS_FROM_STAR: 859 columns = [name for name in columns if name.upper() not in pseudocolumns] 860 861 if not columns or "*" in columns: 862 return 863 864 table_id = id(table) 865 columns_to_exclude = except_columns.get(table_id) or set() 866 renamed_columns = rename_columns.get(table_id, {}) 867 replaced_columns = replace_columns.get(table_id, {}) 868 869 if pivot: 870 if pivot_output_columns and pivot_exclude_columns: 871 pivot_columns = [c for c in columns if c not in pivot_exclude_columns] 872 pivot_columns.extend(pivot_output_columns) 873 else: 874 pivot_columns = pivot.alias_column_names 875 876 if pivot_columns: 877 new_selections.extend( 878 alias(exp.column(name, table=pivot.alias), name, copy=False) 879 for name in pivot_columns 880 if name not in columns_to_exclude 881 ) 882 continue 883 884 for name in columns: 885 if name in columns_to_exclude or name in coalesced_columns: 886 continue 887 if name in using_column_tables and table in using_column_tables[name]: 888 coalesced_columns.add(name) 889 # TODO (mypyc): use a separate variable to avoid reusing `tables` (list) with dict type 890 using_tables = using_column_tables[name] 891 coalesce_args = [exp.column(name, table=table) for table in using_tables] 892 893 new_selections.append( 894 alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False) 895 ) 896 else: 897 alias_ = renamed_columns.get(name, name) 898 selection_expr = replaced_columns.get(name) or exp.column(name, table=table) 899 new_selections.append( 900 alias(selection_expr, alias_, copy=False) 901 if alias_ != name 902 else selection_expr 903 ) 904 905 # Ensures we don't overwrite the initial selections with an empty list 906 if new_selections and 
isinstance(scope_expression, exp.Select): 907 scope_expression.set("expressions", new_selections) 908 909 910def _add_except_columns(expression: exp.Expr, tables, except_columns: dict[int, set[str]]) -> None: 911 except_ = expression.args.get("except_") 912 913 if not except_: 914 return 915 916 columns = {e.name for e in except_} 917 918 for table in tables: 919 except_columns[id(table)] = columns 920 921 922def _add_rename_columns( 923 expression: exp.Expr, tables, rename_columns: dict[int, dict[str, str]] 924) -> None: 925 rename = expression.args.get("rename") 926 927 if not rename: 928 return 929 930 columns = {e.this.name: e.alias for e in rename} 931 932 for table in tables: 933 rename_columns[id(table)] = columns 934 935 936def _add_replace_columns( 937 expression: exp.Expr, tables, replace_columns: dict[int, dict[str, exp.Alias]] 938) -> None: 939 replace = expression.args.get("replace") 940 941 if not replace: 942 return 943 944 columns = {e.alias: e for e in replace} 945 946 for table in tables: 947 replace_columns[id(table)] = columns 948 949 950def qualify_outputs(scope_or_expression: Scope | exp.Expr) -> None: 951 """Ensure all output columns are aliased""" 952 if isinstance(scope_or_expression, exp.Expr): 953 scope = build_scope(scope_or_expression) 954 if not isinstance(scope, Scope): 955 return 956 else: 957 scope = scope_or_expression 958 959 expression = scope.expression 960 961 if not isinstance(expression, exp.Selectable): 962 return 963 964 new_selections = [] 965 966 for i, (selection, aliased_column) in enumerate( 967 itertools.zip_longest(expression.selects, scope.outer_columns) 968 ): 969 if selection is None or isinstance(selection, exp.QueryTransform): 970 break 971 972 if isinstance(selection, exp.Subquery): 973 if not selection.output_name: 974 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 975 elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star: 976 selection = alias( 977 
selection, 978 alias=selection.output_name or f"_col_{i}", 979 copy=False, 980 ) 981 if aliased_column: 982 selection.set("alias", exp.to_identifier(aliased_column)) 983 984 new_selections.append(selection) 985 986 if new_selections and isinstance(expression, exp.Select): 987 expression.set("expressions", new_selections) 988 989 990def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 991 """Makes sure all identifiers that need to be quoted are quoted.""" 992 return expression.transform( 993 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 994 ) # type: ignore 995 996 997def pushdown_cte_alias_columns(scope: Scope) -> None: 998 """ 999 Pushes down the CTE alias columns into the projection, 1000 1001 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 1002 1003 Args: 1004 scope: Scope to find ctes to pushdown aliases. 1005 """ 1006 for cte in scope.ctes: 1007 if cte.alias_column_names and isinstance(cte.this, exp.Select): 1008 new_expressions = [] 1009 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 1010 if isinstance(projection, exp.Alias): 1011 projection.set("alias", exp.to_identifier(_alias)) 1012 else: 1013 projection = alias(projection, alias=_alias) 1014 new_expressions.append(projection) 1015 cte.this.set("expressions", new_expressions)
def
qualify_columns( expression: sqlglot.expressions.core.Expr, schema: dict[str, object] | sqlglot.schema.Schema, expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: bool | None = None, allow_partial_qualification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.core.Expr:
def qualify_columns(
    expression: exp.Expr,
    schema: dict[str, object] | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: bool | None = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expr:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expr to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = schema.dialect or Dialect()
    pseudocolumns = dialect.PSEUDOCOLUMNS

    # These toggles do not vary between scopes, so compute them once up front.
    expand_aliases_early = expand_alias_refs and (
        schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION
    )
    expand_aliases_late = expand_alias_refs and not schema.empty

    for scope in traverse_scope(expression):
        if dialect.PREFER_CTE_ALIAS_COLUMN:
            pushdown_cte_alias_columns(scope)

        # Capture the node and its kind before the qualification passes below mutate the scope.
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        _separate_pseudocolumns(scope, pseudocolumns)

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if expand_aliases_early:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if expand_aliases_late:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns, annotator)
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if dialect.ANNOTATE_ALL_SCOPES:
            annotator.annotate_scope(scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expr to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E, sql: str | None = None) -> ~E:
def validate_qualify_columns(expression: E, sql: str | None = None) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unresolved: list = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        columns = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            # A non-correlated, non-pivot scope must not reference columns from outside itself.
            column = scope.external_columns[0]
            meta = column.this.meta
            for_table = f" for table: '{column.table}'" if column.table else ""

            line, col = meta.get("line"), meta.get("col")
            start, end = meta.get("start"), meta.get("end")

            error_msg = f"Column '{column.name}' could not be resolved{for_table}."
            if line and col:
                error_msg += f" Line: {line}, Col: {col}"
            if sql and start is not None and end is not None:
                formatted_sql = highlight_sql(sql, [(start, end)])[0]
                error_msg += f"\n {formatted_sql}"

            raise OptimizeError(error_msg)

        if columns and scope.pivots and scope.pivots[0].unpivot:
            # New columns produced by the UNPIVOT can't be qualified, but there may be columns
            # under the UNPIVOT's IN clause that can and should be qualified. We recompute
            # this list here to ensure those in the former category will be excluded.
            unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
            columns = [c for c in columns if c not in unpivot_columns]

        unresolved.extend(columns)

    if unresolved:
        first_column = unresolved[0]
        meta = first_column.this.meta

        line, col = meta.get("line"), meta.get("col")
        start, end = meta.get("start"), meta.get("end")

        error_msg = f"Ambiguous column '{first_column.name}'"
        if line and col:
            error_msg += f" (Line: {line}, Col: {col})"
        if sql and start is not None and end is not None:
            formatted_sql = highlight_sql(sql, [(start, end)])[0]
            error_msg += f"\n {formatted_sql}"

        raise OptimizeError(error_msg)

    return expression
Raise an OptimizeError if any columns aren't qualified
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.core.Expr) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expr) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expr):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    expression = scope.expression

    if not isinstance(expression, exp.Selectable):
        return

    new_selections = []

    for index, (projection, outer_name) in enumerate(
        itertools.zip_longest(expression.selects, scope.outer_columns)
    ):
        # Stop on zip_longest padding (more outer columns than selects) or at a TRANSFORM clause.
        if projection is None or isinstance(projection, exp.QueryTransform):
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{index}")))
        elif not projection.is_star and not isinstance(projection, (exp.Alias, exp.Aliases)):
            projection = alias(
                projection, alias=projection.output_name or f"_col_{index}", copy=False
            )

        # An outer (e.g. CTE-declared) column name wins over whatever alias was derived above.
        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        new_selections.append(projection)

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(expression, exp.Select):
        expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def pushdown_cte_alias_columns(scope: Scope) -> None:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Args:
        scope: Scope to find ctes to pushdown aliases.
    """
    for cte in scope.ctes:
        if not cte.alias_column_names or not isinstance(cte.this, exp.Select):
            continue

        new_expressions = []
        for alias_name, projection in zip(cte.alias_column_names, cte.this.expressions):
            if isinstance(projection, exp.Alias):
                projection.set("alias", exp.to_identifier(alias_name))
            else:
                projection = alias(projection, alias=alias_name)
            new_expressions.append(projection)

        cte.this.set("expressions", new_expressions)
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Arguments:
- scope: Scope to find ctes to pushdown aliases.