sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    snowflake_or_oracle = dialect in ("oracle", "snowflake")

    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and snowflake_or_oracle and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the columns produced by an UNPIVOT: the name column (IN clause) and value columns."""
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """Rewrite JOIN ... USING (...) into explicit ON conditions (COALESCE-ing shared columns)."""
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
     => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Replace positional references (ordinals) in the GROUP BY clause with projections."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None:
    """Expand positional references in ORDER BY / DISTINCT ON, qualifying aggregate columns."""
    for modifier_key in ("order", "distinct"):
        modifier = scope.expression.args.get(modifier_key)
        if isinstance(modifier, exp.Distinct):
            modifier = modifier.args.get("on")

        if not isinstance(modifier, exp.Expression):
            continue

        modifier_expressions = modifier.expressions
        if modifier_key == "order":
            modifier_expressions = [ordered.this for ordered in modifier_expressions]

        for original, expanded in zip(
            modifier_expressions,
            _expand_positional_references(
                scope, modifier_expressions, resolver.schema.dialect, alias=True
            ),
        ):
            for agg in original.find_all(exp.AggFunc):
                for col in agg.find_all(exp.Column):
                    if not col.table:
                        col.set("table", resolver.get_table(col.name))

            original.replace(expanded)

        if scope.expression.args.get("group"):
            selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

            for expression in modifier_expressions:
                expression.replace(
                    exp.to_identifier(_select_by_pos(scope, expression).alias)
                    if expression.is_int
                    else selects.get(expression, expression)
                )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """Return `expressions` with integer ordinals replaced by the referenced projections."""
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the 1-indexed projection referenced by the ordinal `node`, or raise."""
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the EXCEPT (...) column names of a star expression, keyed by table identity."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the RENAME (... AS ...) mappings of a star expression, keyed by table identity."""
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    """Record the REPLACE (... AS ...) expressions of a star expression, keyed by table identity."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Lazily build and cache the mapping of source name -> columns for this scope."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

        for column in ambiguous:
            unambiguous_columns.pop(column, None)
        for column in unique.difference(ambiguous):
            unambiguous_columns[column] = table

        return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # By default, only fall back to schema inference when no schema was provided.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    snowflake_or_oracle = dialect in ("oracle", "snowflake")

    # Scopes are traversed innermost-first, so each subquery is qualified before
    # the query that selects from it.
    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and snowflake_or_oracle and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            # The transform mutated the tree in place, so cached scope state is stale.
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects resolve SELECT aliases early (e.g. in GROUP BY); also expand
        # early when there's no schema, since alias refs can't be resolved otherwise.
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)

        # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse)
        # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT
        _expand_order_by_and_distinct_on(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unresolved = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        columns = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            # A column referencing a source outside this scope, with no correlation or
            # pivot to explain it, can never be resolved.
            column = scope.external_columns[0]
            for_table = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

        if columns and scope.pivots and scope.pivots[0].unpivot:
            # Columns produced by the UNPIVOT itself can't be qualified, but columns under
            # its IN clause can and should be; keep only the latter in the ambiguous list.
            produced_by_unpivot = set(_unpivot_columns(scope.pivots[0]))
            columns = [c for c in columns if c not in produced_by_unpivot]

        unresolved.extend(columns)

    if unresolved:
        raise OptimizeError(f"Ambiguous columns: {unresolved}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        maybe_scope = build_scope(scope_or_expression)
        if not isinstance(maybe_scope, Scope):
            # Nothing scope-like to qualify (e.g. a non-query expression).
            return
        scope = maybe_scope
    else:
        scope = scope_or_expression

    selections = []
    for idx, (projection, outer_name) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if projection is None:
            # More outer column names than projections - nothing left to alias.
            break

        if isinstance(projection, exp.Subquery):
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{idx}")))
        elif not isinstance(projection, exp.Alias) and not projection.is_star:
            projection = alias(
                projection,
                alias=projection.output_name or f"_col_{idx}",
                copy=False,
            )

        if outer_name:
            # An outer column alias (e.g. a CTE's column list) shadows any alias above.
            projection.set("alias", exp.to_identifier(outer_name))

        selections.append(projection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if not cte.alias_column_names:
            continue

        projections = []
        # Pair each alias column with the corresponding projection, in order.
        for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            projections.append(projection)

        cte.this.set("expressions", projections)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") >>> pushdown_cte_alias_columns(expression).sql() 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-computed caches; populated on first use by the methods below.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Keyed by (source name, only_visible) since visibility changes the result.
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has no known columns (or a star), assume the
            # column comes from it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the ancestor whose alias is the resolved table name.
            # NOTE(review): if no ancestor's alias matches, `node` becomes None and the
            # `node.args` access below would raise - presumably a match always exists
            # for selected sources; confirm.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # In BigQuery, unnested structs are automatically scoped as tables, so you
                # can directly select a struct field in a query.
                # This handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        # NOTE(review): appends onto the list returned by `named_selects` -
                        # assumes that list is freshly built per call; confirm.
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Maps each selected/lateral source name to its columns, cached after first call.
        # NOTE(review): a name present in both selected and lateral sources is resolved
        # once; the dict comprehension keeps a single entry per name.
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            # A column seen in an earlier source and again here is ambiguous.
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
    self.scope = scope
    self.schema = schema
    # Lazily-computed caches; populated on first use.
    self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
    self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
    self._all_columns: t.Optional[t.Set[str]] = None
    self._infer_schema = infer_schema
    # Keyed by (source name, only_visible) since visibility changes the result.
    self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # If exactly one source has no known columns (or a star), assume the
        # column comes from it.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Query):
        # Walk up to the ancestor whose alias is the resolved table name.
        # NOTE(review): if no ancestor matches, `node` becomes None and the
        # attribute access below would raise; confirm a match always exists.
        while node and node.alias != table_name:
            node = node.parent

    node_alias = node.args.get("alias")
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
@property
def all_columns(self) -> t.Set[str]:
    """All available columns of all sources in this scope"""
    if self._all_columns is None:
        # Flatten the per-source column lists into one deduplicated set.
        self._all_columns = {
            column for columns in self._get_all_source_columns().values() for column in columns
        }
    return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
    """Resolve the source columns for a given source `name`."""
    cache_key = (name, only_visible)
    if cache_key not in self._get_source_columns_cache:
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(
            source.expression, (exp.Values, exp.Unnest)
        ):
            columns = source.expression.named_selects

            # In BigQuery, unnested structs are automatically scoped as tables, so you
            # can directly select a struct field in a query.
            # This handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                if source.expression.is_type(exp.DataType.Type.STRUCT):
                    # NOTE(review): appends onto the list returned by `named_selects` -
                    # assumes that list is freshly built per call; confirm.
                    for k in source.expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            columns = [
                alias or name
                for (name, alias) in itertools.zip_longest(columns, column_aliases)
            ]

        self._get_source_columns_cache[cache_key] = columns

    return self._get_source_columns_cache[cache_key]
Resolve the source columns for a given source `name`.