sqlglot.optimizer.qualify_columns
# Qualify every column reference in a sqlglot AST with its source table.
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError, highlight_sql
from sqlglot.helper import seq_get
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.resolver import Resolver
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.
        dialect: Dialect used when building the schema from a plain dict.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    # The schema's dialect wins over the `dialect` argument from here on.
    dialect = schema.dialect or Dialect()
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        if dialect.PREFER_CTE_ALIAS_COLUMN:
            pushdown_cte_alias_columns(scope)

        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        _separate_pseudocolumns(scope, pseudocolumns)

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects expand alias refs before qualification (or when no schema
        # is available); otherwise it happens after `_qualify_columns` below.
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(
            scope,
            resolver,
            allow_partial_qualification=allow_partial_qualification,
        )

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if dialect.ANNOTATE_ALL_SCOPES:
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E, sql: t.Optional[str] = None) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            # A column referencing no known source (and not explained by a correlated
            # subquery or a pivot) is unresolvable: fail with position info if we have it.
            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                line = column.this.meta.get("line")
                col = column.this.meta.get("col")
                start = column.this.meta.get("start")
                end = column.this.meta.get("end")

                error_msg = f"Column '{column.name}' could not be resolved{for_table}."
                if line and col:
                    error_msg += f" Line: {line}, Col: {col}"
                if sql and start is not None and end is not None:
                    formatted_sql = highlight_sql(sql, [(start, end)])[0]
                    error_msg += f"\n {formatted_sql}"

                raise OptimizeError(error_msg)

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        first_column = all_unqualified_columns[0]
        line = first_column.this.meta.get("line")
        col = first_column.this.meta.get("col")
        start = first_column.this.meta.get("start")
        end = first_column.this.meta.get("end")

        error_msg = f"Ambiguous column '{first_column.name}'"
        if line and col:
            error_msg += f" (Line: {line}, Col: {col})"
        if sql and start is not None and end is not None:
            formatted_sql = highlight_sql(sql, [(start, end)])[0]
            error_msg += f"\n {formatted_sql}"

        raise OptimizeError(error_msg)

    return expression
def _separate_pseudocolumns(scope: Scope, pseudocolumns: t.Set[str]) -> None:
    # Rewrite references to dialect pseudocolumns (names listed in `pseudocolumns`)
    # into exp.Pseudocolumn nodes so they are not qualified like ordinary columns.
    if not pseudocolumns:
        return

    has_pseudocolumns = False
    scope_expression = scope.expression

    for column in scope.columns:
        name = column.name.upper()
        if name not in pseudocolumns:
            continue

        # LEVEL only acts as a pseudocolumn inside a SELECT with a CONNECT BY clause.
        if name != "LEVEL" or (
            isinstance(scope_expression, exp.Select) and scope_expression.args.get("connect")
        ):
            column.replace(exp.Pseudocolumn(**column.args))
            has_pseudocolumns = True

    if has_pseudocolumns:
        # Node replacement invalidates the scope's cached column list.
        scope.clear_cache()


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    # Yields the columns an UNPIVOT produces: the IN-clause name columns first,
    # then every column referenced by the unpivot expressions.
    name_columns = [
        field.this
        for field in unpivot.fields
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    ]
    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))

    return itertools.chain(name_columns, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        # Recursive CTEs rely on their column alias list, so leave them untouched.
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.set("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    # Rewrites JOIN ... USING (...) into explicit ON conditions and returns a mapping
    # of each USING column name to the ordered set of source tables that provide it.
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        # Record the first source that provides each column name.
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # For later joins, the left side may need to coalesce the column
                # across every earlier source that provides it.
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        join.set("using", None)
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Replace unqualified references to USING columns with COALESCE over all providers.
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select) or dialect.DISABLES_ALIAS_REF_EXPANSION:
        return

    # Maps a projection alias to (its expression, its 1-based select position).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}
    replaced = False

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        nonlocal replaced
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                # Don't nest an aggregate inside another aggregate (unless windowed).
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                #   SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If "HAVING x" is expanded to "HAVING max(x.b)", BQ would blindly replace the "x"
                # reference with the projection MAX(x.b), i.e HAVING MAX(MAX(x.b).b), resulting in
                # the error: "Aggregations of aggregations are not allowed"
                if is_having and dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )
                elif dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES and (is_group_by or is_having):
                    column_table = table.name if table else column.table
                    if column_table in projections:
                        # BigQuery's GROUP BY and HAVING clauses get confused if the column name
                        # matches a source name and a projection. For instance:
                        #   SELECT id, ARRAY_AGG(col) AS custom_fields FROM custom_fields GROUP BY id HAVING id >= 1
                        # We should not qualify "id" with "custom_fields" in either clause, since the aggregation shadows the actual table
                        # and we'd get the error: "Column custom_fields contains an aggregation function, which is not allowed in GROUP BY clause"
                        column.replace(exp.to_identifier(column.name))
                        replaced = True
                        return

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if (isinstance(alias_expr, exp.Literal) or alias_expr.is_number) and (
                    literal_index or resolve_table
                ):
                    if literal_index:
                        # Replace the alias with the projection's ordinal position instead.
                        column.replace(exp.Literal.number(i))
                        replaced = True
                else:
                    replaced = True
                    # Parenthesize the expansion, then simplify redundant parens away.
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column, dialect)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    on_right_sub_tree = False
    while parent_scope and not parent_scope.is_cte:
        if parent_scope.is_union:
            on_right_sub_tree = parent_scope.parent.expression.right is parent_scope.expression
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    # and we are in the recursive part (right sub tree) of the CTE
    if parent_scope and on_right_sub_tree:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    if dialect.SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS:
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    if replaced:
        scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: Dialect) -> None:
    # Replace ordinal references in GROUP BY (e.g. GROUP BY 1) with the projections.
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    # Replace ordinal references in ORDER BY / DISTINCT ON with the corresponding
    # projections (or their aliases when a GROUP BY is present).
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            # ORDER BY terms are wrapped in exp.Ordered; unwrap before expanding.
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.dialect, alias=True
            ),
        ):
            # Qualify columns that appear inside aggregates before replacement.
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: Dialect, alias: bool = False
) -> t.List[exp.Expression]:
    # For each integer literal in `expressions`, substitute the select item at that
    # 1-based position (or its alias when `alias=True`); other nodes pass through.
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect.PROJECTION_ALIASES_SHADOW_SOURCE_NAMES:
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Constants, numbers, explode/unnest projections and ambiguous
                # projections are left as positional references.
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.is_number
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    # Resolve a 1-based positional reference to the aliased select item it denotes.
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent STRUCT or JSON field lookup into chained `Dots`.

    These lookups may be parsed as columns (e.g. "col"."field"."field2"), but they need to be
    normalized to `Dot(Dot(...(<table>.<column>, field1), field2, ...))` to be qualified properly.
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        dot_parts = column.meta.pop("dot_parts", [])
        # Only rewrite when the "table" part doesn't actually name a known source
        # (then it must be a column whose fields are being accessed).
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
                was_qualified = True
            else:
                column_table = resolver.get_table(root.name)
                was_qualified = False

            if column_table:
                converted = True
                new_column = exp.column(root, table=column_table)

                if dot_parts:
                    # Remove the actual column parts from the rest of dot parts
                    new_column.meta["dot_parts"] = dot_parts[2 if was_qualified else 1 :]

                column.replace(exp.Dot.build([new_column, *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(
    scope: Scope,
    resolver: Resolver,
    allow_partial_qualification: bool,
) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column)

            if (
                column_table
                and isinstance(source := scope.sources.get(column_table.name), Scope)
                and id(column) in source.column_index
            ):
                continue

            if column_table:
                column.set("table", column_table)
            elif (
                resolver.dialect.TABLES_REFERENCEABLE_AS_COLUMNS
                and len(column.parts) == 1
                and column_name in scope.selected_sources
            ):
                # BigQuery and Postgres allow tables to be referenced as columns, treating them as structs/records
                scope.replace(column, exp.TableColumn(this=column.this))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars_no_parens(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
def _expand_struct_stars_with_parens(expression: exp.Dot) -> t.List[exp.Alias]:
    """[RisingWave] Expand/Flatten (<exp>.bar).*, where bar is a struct column"""

    # it is not (<sub_exp>).* pattern, which means we can't expand
    if not isinstance(expression.this, exp.Paren):
        return []

    # find column definition to get data-type
    dot_column = expression.find(exp.Column)
    if not isinstance(dot_column, exp.Column) or not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    parent = dot_column.parent
    starting_struct = dot_column.type

    # walk up AST and down into struct definition in sync
    while parent is not None:
        if isinstance(parent, exp.Paren):
            parent = parent.parent
            continue

        # if parent is not a dot, then something is wrong
        if not isinstance(parent, exp.Dot):
            return []

        # if the rhs of the dot is star we are done
        rhs = parent.right
        if isinstance(rhs, exp.Star):
            break

        # if it is not identifier, then something is wrong
        if not isinstance(rhs, exp.Identifier):
            return []

        # Check if current rhs identifier is in struct
        matched = False
        for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
            if struct_field_def.name == rhs.name:
                matched = True
                starting_struct = struct_field_def.kind  # update struct
                break

        if not matched:
            return []

        parent = parent.parent

    # build new aliases to expand star
    new_selections = []

    # fetch the outermost parentheses for new aliases
    outer_paren = expression.this

    for struct_field_def in t.cast(exp.DataType, starting_struct).expressions:
        new_identifier = struct_field_def.this.copy()
        new_dot = exp.Dot.build([outer_paren.copy(), new_identifier])
        new_alias = alias(new_dot, new_identifier, copy=False)
        new_selections.append(new_alias)

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    # The helper dicts below are keyed by id(table) of each expanded source.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    if dialect.SUPPORTS_STRUCT_STAR_EXPANSION and any(
        isinstance(col, exp.Dot) for col in scope.stars
    ):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare `SELECT *` expands over every selected source.
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # `SELECT t.*` — expand only over the named source.
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif (
                dialect.SUPPORTS_STRUCT_STAR_EXPANSION
                and not dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS
            ):
                struct_fields = _expand_struct_stars_no_parens(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue
            elif dialect.REQUIRES_PARENTHESIZED_STRUCT_ACCESS:
                struct_fields = _expand_struct_stars_with_parens(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            # Not a star selection — keep as-is.
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns and dialect.EXCLUDES_PSEUDOCOLUMNS_FROM_STAR:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Unknown column set — leave the star expansion alone entirely.
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING columns become a single COALESCE over all providers.
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    # Record `* EXCEPT (...)` column names for each table, keyed by id(table).
    except_ = expression.args.get("except_")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    # Record `* RENAME (old AS new)` mappings for each table, keyed by id(table).
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    # Record `* REPLACE (expr AS name)` mappings for each table, keyed by id(table).
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None or isinstance(selection, exp.QueryTransform):
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            # An outer column alias (e.g. from a CTE alias list) overrides the name.
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 948 elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star: 949 selection = alias( 950 selection, 951 alias=selection.output_name or f"_col_{i}", 952 copy=False, 953 ) 954 if aliased_column: 955 selection.set("alias", exp.to_identifier(aliased_column)) 956 957 new_selections.append(selection) 958 959 if new_selections and isinstance(scope.expression, exp.Select): 960 scope.expression.set("expressions", new_selections) 961 962 963def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 964 """Makes sure all identifiers that need to be quoted are quoted.""" 965 return expression.transform( 966 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 967 ) # type: ignore 968 969 970def pushdown_cte_alias_columns(scope: Scope) -> None: 971 """ 972 Pushes down the CTE alias columns into the projection, 973 974 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 975 976 Args: 977 scope: Scope to find ctes to pushdown aliases. 978 """ 979 for cte in scope.ctes: 980 if cte.alias_column_names and isinstance(cte.this, exp.Select): 981 new_expressions = [] 982 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 983 if isinstance(projection, exp.Alias): 984 projection.set("alias", exp.to_identifier(_alias)) 985 else: 986 projection = alias(projection, alias=_alias) 987 new_expressions.append(projection) 988 cte.this.set("expressions", new_expressions)
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.
        dialect: Dialect passed to `ensure_schema` when building the schema.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    # Default to inferring the schema only when none was provided
    infer_schema = schema.empty if infer_schema is None else infer_schema
    # From here on `dialect` is the schema's dialect; the argument only seeded `ensure_schema`
    dialect = schema.dialect or Dialect()
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        if dialect.PREFER_CTE_ALIAS_COLUMN:
            pushdown_cte_alias_columns(scope)

        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        _separate_pseudocolumns(scope, pseudocolumns)

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Early alias-ref expansion: runs before qualification when there is no
        # schema to resolve against, or when the dialect forces it
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ONLY_GROUP_ALIAS_REF,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(
            scope,
            resolver,
            allow_partial_qualification=allow_partial_qualification,
        )

        # Late alias-ref expansion: runs after qualification when a schema is available
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if dialect.ANNOTATE_ALL_SCOPES:
            annotator.annotate_scope(scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
- dialect: Dialect passed to `ensure_schema` when building the schema.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E, sql: Optional[str] = None) -> ~E:
def _column_error_message(
    column: exp.Column, sql: t.Optional[str], base_msg: str, loc_template: str
) -> str:
    """Append source-location details (line/col and a highlighted snippet) to an error message."""
    meta = column.this.meta
    line = meta.get("line")
    col = meta.get("col")
    start = meta.get("start")
    end = meta.get("end")

    error_msg = base_msg
    if line and col:
        error_msg += loc_template.format(line=line, col=col)
    if sql and start is not None and end is not None:
        formatted_sql = highlight_sql(sql, [(start, end)])[0]
        error_msg += f"\n {formatted_sql}"
    return error_msg


def validate_qualify_columns(expression: E, sql: t.Optional[str] = None) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Args:
        expression: The (already qualified) expression to validate.
        sql: Original SQL text; when provided, error messages include a
            highlighted snippet pointing at the offending column.

    Returns:
        The validated expression, unchanged.

    Raises:
        OptimizeError: If a column could not be resolved or is ambiguous.
    """
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            # External columns are only legal in correlated subqueries or pivots
            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(
                    _column_error_message(
                        column,
                        sql,
                        f"Column '{column.name}' could not be resolved{for_table}.",
                        " Line: {line}, Col: {col}",
                    )
                )

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        first_column = all_unqualified_columns[0]
        raise OptimizeError(
            _column_error_message(
                first_column,
                sql,
                f"Ambiguous column '{first_column.name}'",
                " (Line: {line}, Col: {col})",
            )
        )

    return expression
Raise an OptimizeError if any columns aren't qualified
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
930def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 931 """Ensure all output columns are aliased""" 932 if isinstance(scope_or_expression, exp.Expression): 933 scope = build_scope(scope_or_expression) 934 if not isinstance(scope, Scope): 935 return 936 else: 937 scope = scope_or_expression 938 939 new_selections = [] 940 for i, (selection, aliased_column) in enumerate( 941 itertools.zip_longest(scope.expression.selects, scope.outer_columns) 942 ): 943 if selection is None or isinstance(selection, exp.QueryTransform): 944 break 945 946 if isinstance(selection, exp.Subquery): 947 if not selection.output_name: 948 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 949 elif not isinstance(selection, (exp.Alias, exp.Aliases)) and not selection.is_star: 950 selection = alias( 951 selection, 952 alias=selection.output_name or f"_col_{i}", 953 copy=False, 954 ) 955 if aliased_column: 956 selection.set("alias", exp.to_identifier(aliased_column)) 957 958 new_selections.append(selection) 959 960 if new_selections and isinstance(scope.expression, exp.Select): 961 scope.expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
964def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 965 """Makes sure all identifiers that need to be quoted are quoted.""" 966 return expression.transform( 967 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 968 ) # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
971def pushdown_cte_alias_columns(scope: Scope) -> None: 972 """ 973 Pushes down the CTE alias columns into the projection, 974 975 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 976 977 Args: 978 scope: Scope to find ctes to pushdown aliases. 979 """ 980 for cte in scope.ctes: 981 if cte.alias_column_names and isinstance(cte.this, exp.Select): 982 new_expressions = [] 983 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 984 if isinstance(projection, exp.Alias): 985 projection.set("alias", exp.to_identifier(_alias)) 986 else: 987 projection = alias(projection, alias=_alias) 988 new_expressions.append(projection) 989 cte.this.set("expressions", new_expressions)
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Arguments:
- scope: Scope to find ctes to pushdown aliases.