sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    bigquery = dialect == "bigquery"

    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=bigquery,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if bigquery:
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    # Columns produced by an UNPIVOT: the "name" columns from the IN clauses,
    # followed by the "value" columns from the UNPIVOT's expressions.
    name_columns = [
        field.this
        for field in unpivot.fields
        if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    ]
    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))

    return itertools.chain(name_columns, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            # Recursive CTEs rely on their declared column aliases, so keep them
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite JOIN ... USING (...) into explicit ON conditions.

    Returns a mapping of each USING column name to an ordered set (dict with
    None values) of the source names that contribute it, so later steps can
    COALESCE references to those columns.
    """
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        # Record the first source that provides each column name
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

                table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.

    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
        => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select) or dialect == "oracle":
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
    projections = {s.alias_or_name for s in expression.selects}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        is_having = isinstance(node, exp.Having)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            skip_replace = False
            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))

            if alias_expr:
                skip_replace = bool(
                    alias_expr.find(exp.AggFunc)
                    and column.find_ancestor(exp.AggFunc)
                    and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                )

                # BigQuery's having clause gets confused if an alias matches a source.
                # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1;
                # If HAVING x is expanded to max(x.b), bigquery treats x as the new projection x instead of the table
                if is_having and dialect == "bigquery":
                    skip_replace = skip_replace or any(
                        node.parts[0].name in projections
                        for node in alias_expr.find_all(exp.Column)
                    )

            if table and (not alias_expr or skip_replace):
                column.set("table", table)
            elif not column.table and alias_expr and not skip_replace:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    # Replace positional references (e.g. GROUP BY 1) with the projections they point to
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    # Expand positional references in ORDER BY and DISTINCT ON, qualifying any
    # columns that appear inside aggregate functions of the original expressions
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    # For each integer literal, substitute either the aliased column name
    # (alias=True) or a copy of the referenced projection itself
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    # Look up the 1-based projection referenced by an integer literal
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns: t.Set[str] = set()

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            for field in pivot.fields:
                if isinstance(field, exp.In):
                    pivot_exclude_columns.update(
                        c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                    )

        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    # Record the EXCEPT (...) column names of a star expression, keyed by table id
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    # Record the RENAME (... AS ...) mappings of a star expression, keyed by table id
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    # Record the REPLACE (... AS ...) expressions of a star expression, keyed by table id
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source lacks schema information, assume the column belongs to it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
                set_op = source.expression

                # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
                on_column_list = set_op.args.get("on")

                if on_column_list:
                    # The resulting columns are the columns in the ON clause:
                    # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
                    columns = [col.name for col in on_column_list]
                elif set_op.side or set_op.kind:
                    side = set_op.side
                    kind = set_op.kind

                    left = set_op.left.named_selects
                    right = set_op.right.named_selects

                    # We use dict.fromkeys to deduplicate keys and maintain insertion order
                    if side == "LEFT":
                        columns = left
                    elif side == "FULL":
                        columns = list(dict.fromkeys(left + right))
                    elif kind == "INNER":
                        columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
                else:
                    columns = set_op.named_selects
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Lazily build (and cache) the mapping of every selected/lateral source to its columns
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
20def qualify_columns( 21 expression: exp.Expression, 22 schema: t.Dict | Schema, 23 expand_alias_refs: bool = True, 24 expand_stars: bool = True, 25 infer_schema: t.Optional[bool] = None, 26 allow_partial_qualification: bool = False, 27) -> exp.Expression: 28 """ 29 Rewrite sqlglot AST to have fully qualified columns. 30 31 Example: 32 >>> import sqlglot 33 >>> schema = {"tbl": {"col": "INT"}} 34 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 35 >>> qualify_columns(expression, schema).sql() 36 'SELECT tbl.col AS col FROM tbl' 37 38 Args: 39 expression: Expression to qualify. 40 schema: Database schema. 41 expand_alias_refs: Whether to expand references to aliases. 42 expand_stars: Whether to expand star queries. This is a necessary step 43 for most of the optimizer's rules to work; do not set to False unless you 44 know what you're doing! 45 infer_schema: Whether to infer the schema if missing. 46 allow_partial_qualification: Whether to allow partial qualification. 47 48 Returns: 49 The qualified expression. 
50 51 Notes: 52 - Currently only handles a single PIVOT or UNPIVOT operator 53 """ 54 schema = ensure_schema(schema) 55 annotator = TypeAnnotator(schema) 56 infer_schema = schema.empty if infer_schema is None else infer_schema 57 dialect = Dialect.get_or_raise(schema.dialect) 58 pseudocolumns = dialect.PSEUDOCOLUMNS 59 bigquery = dialect == "bigquery" 60 61 for scope in traverse_scope(expression): 62 scope_expression = scope.expression 63 is_select = isinstance(scope_expression, exp.Select) 64 65 if is_select and scope_expression.args.get("connect"): 66 # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL 67 # pseudocolumn, which doesn't belong to a table, so we change it into an identifier 68 scope_expression.transform( 69 lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n, 70 copy=False, 71 ) 72 scope.clear_cache() 73 74 resolver = Resolver(scope, schema, infer_schema=infer_schema) 75 _pop_table_column_aliases(scope.ctes) 76 _pop_table_column_aliases(scope.derived_tables) 77 using_column_tables = _expand_using(scope, resolver) 78 79 if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs: 80 _expand_alias_refs( 81 scope, 82 resolver, 83 dialect, 84 expand_only_groupby=bigquery, 85 ) 86 87 _convert_columns_to_dots(scope, resolver) 88 _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification) 89 90 if not schema.empty and expand_alias_refs: 91 _expand_alias_refs(scope, resolver, dialect) 92 93 if is_select: 94 if expand_stars: 95 _expand_stars( 96 scope, 97 resolver, 98 using_column_tables, 99 pseudocolumns, 100 annotator, 101 ) 102 qualify_outputs(scope) 103 104 _expand_group_by(scope, dialect) 105 106 # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse) 107 # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT 108 _expand_order_by_and_distinct_on(scope, resolver) 109 110 if bigquery: 111 
annotator.annotate_scope(scope) 112 113 return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
116def validate_qualify_columns(expression: E) -> E: 117 """Raise an `OptimizeError` if any columns aren't qualified""" 118 all_unqualified_columns = [] 119 for scope in traverse_scope(expression): 120 if isinstance(scope.expression, exp.Select): 121 unqualified_columns = scope.unqualified_columns 122 123 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 124 column = scope.external_columns[0] 125 for_table = f" for table: '{column.table}'" if column.table else "" 126 raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") 127 128 if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: 129 # New columns produced by the UNPIVOT can't be qualified, but there may be columns 130 # under the UNPIVOT's IN clause that can and should be qualified. We recompute 131 # this list here to ensure those in the former category will be excluded. 132 unpivot_columns = set(_unpivot_columns(scope.pivots[0])) 133 unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] 134 135 all_unqualified_columns.extend(unqualified_columns) 136 137 if all_unqualified_columns: 138 raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") 139 140 return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
761def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 762 """Ensure all output columns are aliased""" 763 if isinstance(scope_or_expression, exp.Expression): 764 scope = build_scope(scope_or_expression) 765 if not isinstance(scope, Scope): 766 return 767 else: 768 scope = scope_or_expression 769 770 new_selections = [] 771 for i, (selection, aliased_column) in enumerate( 772 itertools.zip_longest(scope.expression.selects, scope.outer_columns) 773 ): 774 if selection is None: 775 break 776 777 if isinstance(selection, exp.Subquery): 778 if not selection.output_name: 779 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 780 elif not isinstance(selection, exp.Alias) and not selection.is_star: 781 selection = alias( 782 selection, 783 alias=selection.output_name or f"_col_{i}", 784 copy=False, 785 ) 786 if aliased_column: 787 selection.set("alias", exp.to_identifier(aliased_column)) 788 789 new_selections.append(selection) 790 791 if isinstance(scope.expression, exp.Select): 792 scope.expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
795def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 796 """Makes sure all identifiers that need to be quoted are quoted.""" 797 return expression.transform( 798 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 799 ) # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
802def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: 803 """ 804 Pushes down the CTE alias columns into the projection, 805 806 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 807 808 Example: 809 >>> import sqlglot 810 >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") 811 >>> pushdown_cte_alias_columns(expression).sql() 812 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' 813 814 Args: 815 expression: Expression to pushdown. 816 817 Returns: 818 The expression with the CTE aliases pushed down into the projection. 819 """ 820 for cte in expression.find_all(exp.CTE): 821 if cte.alias_column_names: 822 new_expressions = [] 823 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 824 if isinstance(projection, exp.Alias): 825 projection.set("alias", _alias) 826 else: 827 projection = alias(projection, alias=_alias) 828 new_expressions.append(projection) 829 cte.this.set("expressions", new_expressions) 830 831 return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") >>> pushdown_cte_alias_columns(expression).sql() 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
834class Resolver: 835 """ 836 Helper for resolving columns. 837 838 This is a class so we can lazily load some things and easily share them across functions. 839 """ 840 841 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 842 self.scope = scope 843 self.schema = schema 844 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 845 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 846 self._all_columns: t.Optional[t.Set[str]] = None 847 self._infer_schema = infer_schema 848 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {} 849 850 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 851 """ 852 Get the table for a column name. 853 854 Args: 855 column_name: The column name to find the table for. 856 Returns: 857 The table name if it can be found/inferred. 858 """ 859 if self._unambiguous_columns is None: 860 self._unambiguous_columns = self._get_unambiguous_columns( 861 self._get_all_source_columns() 862 ) 863 864 table_name = self._unambiguous_columns.get(column_name) 865 866 if not table_name and self._infer_schema: 867 sources_without_schema = tuple( 868 source 869 for source, columns in self._get_all_source_columns().items() 870 if not columns or "*" in columns 871 ) 872 if len(sources_without_schema) == 1: 873 table_name = sources_without_schema[0] 874 875 if table_name not in self.scope.selected_sources: 876 return exp.to_identifier(table_name) 877 878 node, _ = self.scope.selected_sources.get(table_name) 879 880 if isinstance(node, exp.Query): 881 while node and node.alias != table_name: 882 node = node.parent 883 884 node_alias = node.args.get("alias") 885 if node_alias: 886 return exp.to_identifier(node_alias.this) 887 888 return exp.to_identifier(table_name) 889 890 @property 891 def all_columns(self) -> t.Set[str]: 892 """All available columns of all sources in this scope""" 893 if self._all_columns is None: 894 self._all_columns = { 895 column for 
columns in self._get_all_source_columns().values() for column in columns 896 } 897 return self._all_columns 898 899 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 900 """Resolve the source columns for a given source `name`.""" 901 cache_key = (name, only_visible) 902 if cache_key not in self._get_source_columns_cache: 903 if name not in self.scope.sources: 904 raise OptimizeError(f"Unknown table: {name}") 905 906 source = self.scope.sources[name] 907 908 if isinstance(source, exp.Table): 909 columns = self.schema.column_names(source, only_visible) 910 elif isinstance(source, Scope) and isinstance( 911 source.expression, (exp.Values, exp.Unnest) 912 ): 913 columns = source.expression.named_selects 914 915 # in bigquery, unnest structs are automatically scoped as tables, so you can 916 # directly select a struct field in a query. 917 # this handles the case where the unnest is statically defined. 918 if self.schema.dialect == "bigquery": 919 if source.expression.is_type(exp.DataType.Type.STRUCT): 920 for k in source.expression.type.expressions: # type: ignore 921 columns.append(k.name) 922 elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation): 923 set_op = source.expression 924 925 # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME 926 on_column_list = set_op.args.get("on") 927 928 if on_column_list: 929 # The resulting columns are the columns in the ON clause: 930 # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...) 
931 columns = [col.name for col in on_column_list] 932 elif set_op.side or set_op.kind: 933 side = set_op.side 934 kind = set_op.kind 935 936 left = set_op.left.named_selects 937 right = set_op.right.named_selects 938 939 # We use dict.fromkeys to deduplicate keys and maintain insertion order 940 if side == "LEFT": 941 columns = left 942 elif side == "FULL": 943 columns = list(dict.fromkeys(left + right)) 944 elif kind == "INNER": 945 columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys()) 946 else: 947 columns = set_op.named_selects 948 else: 949 columns = source.expression.named_selects 950 951 node, _ = self.scope.selected_sources.get(name) or (None, None) 952 if isinstance(node, Scope): 953 column_aliases = node.expression.alias_column_names 954 elif isinstance(node, exp.Expression): 955 column_aliases = node.alias_column_names 956 else: 957 column_aliases = [] 958 959 if column_aliases: 960 # If the source's columns are aliased, their aliases shadow the corresponding column names. 961 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 962 columns = [ 963 alias or name 964 for (name, alias) in itertools.zip_longest(columns, column_aliases) 965 ] 966 967 self._get_source_columns_cache[cache_key] = columns 968 969 return self._get_source_columns_cache[cache_key] 970 971 def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]: 972 if self._source_columns is None: 973 self._source_columns = { 974 source_name: self.get_source_columns(source_name) 975 for source_name, source in itertools.chain( 976 self.scope.selected_sources.items(), self.scope.lateral_sources.items() 977 ) 978 } 979 return self._source_columns 980 981 def _get_unambiguous_columns( 982 self, source_columns: t.Dict[str, t.Sequence[str]] 983 ) -> t.Mapping[str, str]: 984 """ 985 Find all the unambiguous columns in sources. 986 987 Args: 988 source_columns: Mapping of names to source columns. 
989 990 Returns: 991 Mapping of column name to source name. 992 """ 993 if not source_columns: 994 return {} 995 996 source_columns_pairs = list(source_columns.items()) 997 998 first_table, first_columns = source_columns_pairs[0] 999 1000 if len(source_columns_pairs) == 1: 1001 # Performance optimization - avoid copying first_columns if there is only one table. 1002 return SingleValuedMapping(first_columns, first_table) 1003 1004 unambiguous_columns = {col: first_table for col in first_columns} 1005 all_columns = set(unambiguous_columns) 1006 1007 for table, columns in source_columns_pairs[1:]: 1008 unique = set(columns) 1009 ambiguous = all_columns.intersection(unique) 1010 all_columns.update(columns) 1011 1012 for column in ambiguous: 1013 unambiguous_columns.pop(column, None) 1014 for column in unique.difference(ambiguous): 1015 unambiguous_columns[column] = table 1016 1017 return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
841 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 842 self.scope = scope 843 self.schema = schema 844 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 845 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 846 self._all_columns: t.Optional[t.Set[str]] = None 847 self._infer_schema = infer_schema 848 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
850 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 851 """ 852 Get the table for a column name. 853 854 Args: 855 column_name: The column name to find the table for. 856 Returns: 857 The table name if it can be found/inferred. 858 """ 859 if self._unambiguous_columns is None: 860 self._unambiguous_columns = self._get_unambiguous_columns( 861 self._get_all_source_columns() 862 ) 863 864 table_name = self._unambiguous_columns.get(column_name) 865 866 if not table_name and self._infer_schema: 867 sources_without_schema = tuple( 868 source 869 for source, columns in self._get_all_source_columns().items() 870 if not columns or "*" in columns 871 ) 872 if len(sources_without_schema) == 1: 873 table_name = sources_without_schema[0] 874 875 if table_name not in self.scope.selected_sources: 876 return exp.to_identifier(table_name) 877 878 node, _ = self.scope.selected_sources.get(table_name) 879 880 if isinstance(node, exp.Query): 881 while node and node.alias != table_name: 882 node = node.parent 883 884 node_alias = node.args.get("alias") 885 if node_alias: 886 return exp.to_identifier(node_alias.this) 887 888 return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
890 @property 891 def all_columns(self) -> t.Set[str]: 892 """All available columns of all sources in this scope""" 893 if self._all_columns is None: 894 self._all_columns = { 895 column for columns in self._get_all_source_columns().values() for column in columns 896 } 897 return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
899 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 900 """Resolve the source columns for a given source `name`.""" 901 cache_key = (name, only_visible) 902 if cache_key not in self._get_source_columns_cache: 903 if name not in self.scope.sources: 904 raise OptimizeError(f"Unknown table: {name}") 905 906 source = self.scope.sources[name] 907 908 if isinstance(source, exp.Table): 909 columns = self.schema.column_names(source, only_visible) 910 elif isinstance(source, Scope) and isinstance( 911 source.expression, (exp.Values, exp.Unnest) 912 ): 913 columns = source.expression.named_selects 914 915 # in bigquery, unnest structs are automatically scoped as tables, so you can 916 # directly select a struct field in a query. 917 # this handles the case where the unnest is statically defined. 918 if self.schema.dialect == "bigquery": 919 if source.expression.is_type(exp.DataType.Type.STRUCT): 920 for k in source.expression.type.expressions: # type: ignore 921 columns.append(k.name) 922 elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation): 923 set_op = source.expression 924 925 # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME 926 on_column_list = set_op.args.get("on") 927 928 if on_column_list: 929 # The resulting columns are the columns in the ON clause: 930 # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...) 
931 columns = [col.name for col in on_column_list] 932 elif set_op.side or set_op.kind: 933 side = set_op.side 934 kind = set_op.kind 935 936 left = set_op.left.named_selects 937 right = set_op.right.named_selects 938 939 # We use dict.fromkeys to deduplicate keys and maintain insertion order 940 if side == "LEFT": 941 columns = left 942 elif side == "FULL": 943 columns = list(dict.fromkeys(left + right)) 944 elif kind == "INNER": 945 columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys()) 946 else: 947 columns = set_op.named_selects 948 else: 949 columns = source.expression.named_selects 950 951 node, _ = self.scope.selected_sources.get(name) or (None, None) 952 if isinstance(node, Scope): 953 column_aliases = node.expression.alias_column_names 954 elif isinstance(node, exp.Expression): 955 column_aliases = node.alias_column_names 956 else: 957 column_aliases = [] 958 959 if column_aliases: 960 # If the source's columns are aliased, their aliases shadow the corresponding column names. 961 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 962 columns = [ 963 alias or name 964 for (name, alias) in itertools.zip_longest(columns, column_aliases) 965 ] 966 967 self._get_source_columns_cache[cache_key] = columns 968 969 return self._get_source_columns_cache[cache_key]
Resolve the source columns for a given source `name`.