sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
    dialect: DialectType = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema, dialect=dialect)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    name_columns = [
        field.this
        for field in unpivot.fields
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    ]
    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))

    return itertools.chain(name_columns, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
        => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select) or dialect == "oracle":
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            # SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            # SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result in FUNC(FUNC(col))
            # This is not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If HAVING x is expanded to max(x.b), bigquery treats x as the new projection x instead of the table
                if is_having and dialect == "bigquery":
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct, e.g. t.c.f1.f2.*, find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None or isinstance(selection, exp.QueryTransform):
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                set_op = source.expression

                # BigQuery specific set operations modifiers, e.g. INNER UNION ALL BY NAME
                on_column_list = set_op.args.get("on")

                if on_column_list:
                    # The resulting columns are the columns in the ON clause:
                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
                    columns = [col.name for col in on_column_list]
                elif set_op.side or set_op.kind:
                    side = set_op.side
                    kind = set_op.kind

                    left = set_op.left.named_selects
                    right = set_op.right.named_selects

                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
                    if side == "LEFT":
                        columns = left
                    elif side == "FULL":
                        columns = list(dict.fromkeys(left + right))
                    elif kind == "INNER":
                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
                else:
                    columns = set_op.named_selects
            else:
                select = seq_get(source.expression.selects, 0)

                if isinstance(select, exp.QueryTransform):
                    # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html
                    schema = select.args.get("schema")
                    columns = [c.name for c in schema.expressions] if schema else ["key", "value"]
                else:
                    columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns(expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
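Beyond the doctest above, star projections are also expanded into explicit, qualified columns when a schema is available. A minimal sketch (the table and column names are hypothetical):

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

# Hypothetical schema: table "tbl" with two typed columns.
schema = {"tbl": {"id": "INT", "name": "TEXT"}}

expression = sqlglot.parse_one("SELECT * FROM tbl")
print(qualify_columns(expression, schema).sql())
# With expand_stars=True (the default), the star becomes explicit columns:
# SELECT tbl.id AS id, tbl.name AS name FROM tbl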
def validate_qualify_columns(expression: ~E) -> ~E:
Raise an OptimizeError if any columns aren't qualified
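A sketch of typical usage after qualification; the tables are hypothetical and the exact error message may vary:

import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

# Without a schema, "col" cannot be attributed to either join source,
# so it remains unqualified after qualify_columns.
expression = sqlglot.parse_one("SELECT col FROM a CROSS JOIN b")
qualified = qualify_columns(expression, schema={})

try:
    validate_qualify_columns(qualified)
except OptimizeError as err:
    print(err)  # lists the ambiguous/unqualified columns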
def qualify_outputs(scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
Ensure all output columns are aliased
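A minimal sketch of the default aliasing (the query is hypothetical; per the source above, unnamed projections get positional `_col_{i}` aliases):

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_outputs

expression = sqlglot.parse_one("SELECT a + 1, b FROM t")
qualify_outputs(expression)
print(expression.sql())
# The unnamed projection receives a positional alias:
# SELECT a + 1 AS _col_0, b AS b FROM t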
def quote_identifiers(expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
Makes sure all identifiers that need to be quoted are quoted.
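A minimal sketch; with identify=True (the default), every identifier is quoted using the target dialect's quoting style:

import sqlglot
from sqlglot.optimizer.qualify_columns import quote_identifiers

expression = sqlglot.parse_one("SELECT a FROM t")
print(quote_identifiers(expression).sql())
# SELECT "a" FROM "t"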
def pushdown_cte_alias_columns(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class Resolver:
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
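A sketch of constructing a Resolver by hand, outside of qualify_columns (the schema and query are hypothetical):

import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

# ensure_schema turns the plain dict into a Schema instance.
schema = ensure_schema({"t1": {"a": "INT", "b": "INT"}, "t2": {"b": "INT", "c": "INT"}})
scope = build_scope(sqlglot.parse_one("SELECT * FROM t1 CROSS JOIN t2"))
resolver = Resolver(scope, schema)

print(resolver.all_columns)  # {'a', 'b', 'c'}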
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
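Continuing the construction sketch above (setup repeated so the snippet runs standalone): "a" resolves unambiguously, while "b" appears in both sources and therefore cannot be resolved:

import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t1": {"a": "INT", "b": "INT"}, "t2": {"b": "INT", "c": "INT"}})
scope = build_scope(sqlglot.parse_one("SELECT * FROM t1 CROSS JOIN t2"))
resolver = Resolver(scope, schema)

print(resolver.get_table("a"))  # Identifier for t1: only t1 has "a"
print(resolver.get_table("b"))  # None: "b" is ambiguous between t1 and t2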
all_columns: Set[str]
All available columns of all sources in this scope
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
Resolve the source columns for a given source name.
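A sketch of the alias-shadowing behavior described in the source's comments (the query and names are hypothetical):

import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

# The derived table's column "x" is shadowed by the source's alias list, sub(y).
scope = build_scope(sqlglot.parse_one("SELECT * FROM (SELECT 1 AS x) AS sub(y)"))
resolver = Resolver(scope, ensure_schema({}))

print(resolver.get_source_columns("sub"))  # ['y']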