sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # With an empty schema we must infer columns from the query itself by default
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    snowflake_or_oracle = dialect in ("oracle", "snowflake")

    # traverse_scope yields scopes innermost-first, so nested queries are
    # qualified before the queries that select from them
    for scope in traverse_scope(expression):
        scope_expression = scope.expression
        is_select = isinstance(scope_expression, exp.Select)

        if is_select and snowflake_or_oracle and scope_expression.args.get("connect"):
            # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
            # pseudocolumn, which doesn't belong to a table, so we change it into an identifier
            scope_expression.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        # Drop derived-table/CTE column aliases up front: they shadow the
        # underlying column names and would confuse resolution otherwise
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects (and the schemaless case) require alias expansion
        # before column qualification rather than after it
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            # BigQuery star expansion relies on type info, so keep annotations fresh
            annotator.annotate_scope(scope)

    return expression
def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unresolved: t.List[exp.Column] = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        missing = scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            column = scope.external_columns[0]
            for_table = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

        if missing and scope.pivots and scope.pivots[0].unpivot:
            # New columns produced by the UNPIVOT can't be qualified, but there may be columns
            # under the UNPIVOT's IN clause that can and should be qualified. We recompute
            # this list here to ensure those in the former category will be excluded.
            unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
            missing = [c for c in missing if c not in unpivot_columns]

        unresolved.extend(missing)

    if unresolved:
        raise OptimizeError(f"Ambiguous columns: {unresolved}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the columns an UNPIVOT produces: the IN-clause name column (if any), then value columns."""
    field = unpivot.args.get("field")

    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        yield field.this

    for e in unpivot.expressions:
        yield from e.find_all(exp.Column)
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        # Recursive CTEs keep their column list: it defines the recursion's schema
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite every JOIN ... USING (...) in `scope` into an equivalent JOIN ... ON condition.

    Unqualified references to USING columns elsewhere in the scope are replaced by
    COALESCE over all sources that carry that column.

    Returns:
        Mapping of each USING column name to an ordered "set" (dict keyed by source
        name) of the sources it was joined on.
    """
    # Running map of column name -> first source that provides it
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    # Sources that are not themselves join targets, i.e. the FROM-clause sources
    ordered = [key for key in scope.selected_sources if key not in names]

    if names and not ordered:
        raise OptimizeError(f"Joins {names} missing source table {scope.expression}")

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)
        is_semi_or_anti_join = join.is_semi_or_anti_join

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only raise when we have enough schema info to know the join is impossible
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # For later joins the left side must coalesce over every earlier
                # source that exposes this column
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})

            # Do not update the dict if this was a SEMI/ANTI join in
            # order to avoid generating COALESCE columns for this join pair
            if not is_semi_or_anti_join:
                if table not in tables:
                    tables[table] = None
                if join_table not in tables:
                    tables[join_table] = None

        # Replace USING with the equivalent ON condition
        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.
    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
        => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # alias name -> (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Guard against expanding into a nested aggregation, e.g. SUM(x) AS y ... HAVING SUM(y),
            # unless the reference is actually inside a window, where it's legal
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        # GROUP BY a literal alias becomes a positional reference
                        column.replace(exp.Literal.number(i))
                else:
                    # Parenthesize to preserve precedence, then simplify if the
                    # parentheses turn out to be unnecessary
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Replace positional references in GROUP BY (e.g. GROUP BY 1) with the projections they point to."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Replace positional ORDER BY references with the projections (or their aliases)
    they point to, and qualify columns inside aggregate functions in ORDER BY.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        # Qualify any columns referenced inside aggregates in the ORDER BY clause
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # When there's a GROUP BY, ORDER BY must reference output names, not
        # the underlying expressions, so map projections to their aliases
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """
    For each integer literal in `expressions`, substitute the projection at that
    1-based position (or its alias when `alias=True`); other nodes pass through unchanged.
    """
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    # Keep the positional reference: expanding a constant, an
                    # exploding projection or an ambiguous name would change semantics
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the (aliased) projection at the 1-based position given by `node`, or raise."""
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # A "table" that isn't a known source (here or in the parent scope for
        # correlated subqueries) must actually be a struct column
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            # The column is already qualified; just validate it against the schema
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    # Qualify any remaining unqualified columns inside PIVOT expressions themselves
    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    # All three maps below are keyed by id(table-name string) so that distinct
    # occurrences of the same name don't collide
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                # Columns in the UNPIVOT's IN clause are consumed by the operator
                # and must not appear in the expanded star
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare star: expand over every selected source
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified star: t.*
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                # Struct star: t.struct_col.*
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Not enough schema info to expand; leave the query untouched
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING columns collapse to a single COALESCE projection,
                    # emitted once for all the joined sources
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the star's EXCEPT column names for each table, keyed by id(table)."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns
def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the star's RENAME mappings (old name -> new alias) for each table, keyed by id(table)."""
    rename = expression.args.get("rename")
    if not rename:
        return

    mapping = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = mapping


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    """Record the star's REPLACE expressions (alias -> aliased expression) for each table, keyed by id(table)."""
    replace = expression.args.get("replace")
    if not replace:
        return

    mapping = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = mapping


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    selections = []
    pairs = itertools.zip_longest(scope.expression.selects, scope.outer_columns)

    for i, (selection, outer_name) in enumerate(pairs):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            # A scalar subquery projection gets a table alias, not a column alias
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )

        # An outer column list (e.g. a derived table's alias columns) overrides any alias
        if outer_name:
            selection.set("alias", exp.to_identifier(outer_name))

        selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            # zip truncates to the shorter sequence, so extra projections keep their own aliases
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, populated on first use
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # Fall back to the single source whose columns we don't know;
            # with more than one such source the column stays unresolved
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the node that actually carries the alias
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Compute (once) the mapping of every selected/lateral source to its columns."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column seen in more than one source is ambiguous and is dropped
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite the sqlglot AST so every column reference is fully qualified.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. Most optimizer rules
            depend on this step; only disable it if you know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS
    dialect_has_level = dialect in ("oracle", "snowflake")

    for scope in traverse_scope(expression):
        query = scope.expression
        is_select = isinstance(query, exp.Select)

        if is_select and dialect_has_level and query.args.get("connect"):
            # Snowflake / Oracle CONNECT BY queries may reference the LEVEL
            # pseudocolumn, which belongs to no table, so rewrite it into a
            # plain identifier before qualification
            query.transform(
                lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
                copy=False,
            )
            scope.clear_cache()

        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Early alias-ref expansion (pre-qualification) for schemaless inputs
        # or dialects that require it
        if expand_alias_refs and (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION):
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        # Late alias-ref expansion (post-qualification) when a schema exists
        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver, dialect)

        if is_select:
            if expand_stars:
                _expand_stars(scope, resolver, using_column_tables, pseudocolumns, annotator)
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
114def validate_qualify_columns(expression: E) -> E: 115 """Raise an `OptimizeError` if any columns aren't qualified""" 116 all_unqualified_columns = [] 117 for scope in traverse_scope(expression): 118 if isinstance(scope.expression, exp.Select): 119 unqualified_columns = scope.unqualified_columns 120 121 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 122 column = scope.external_columns[0] 123 for_table = f" for table: '{column.table}'" if column.table else "" 124 raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") 125 126 if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: 127 # New columns produced by the UNPIVOT can't be qualified, but there may be columns 128 # under the UNPIVOT's IN clause that can and should be qualified. We recompute 129 # this list here to ensure those in the former category will be excluded. 130 unpivot_columns = set(_unpivot_columns(scope.pivots[0])) 131 unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] 132 133 all_unqualified_columns.extend(unqualified_columns) 134 135 if all_unqualified_columns: 136 raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") 137 138 return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
744def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 745 """Ensure all output columns are aliased""" 746 if isinstance(scope_or_expression, exp.Expression): 747 scope = build_scope(scope_or_expression) 748 if not isinstance(scope, Scope): 749 return 750 else: 751 scope = scope_or_expression 752 753 new_selections = [] 754 for i, (selection, aliased_column) in enumerate( 755 itertools.zip_longest(scope.expression.selects, scope.outer_columns) 756 ): 757 if selection is None: 758 break 759 760 if isinstance(selection, exp.Subquery): 761 if not selection.output_name: 762 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 763 elif not isinstance(selection, exp.Alias) and not selection.is_star: 764 selection = alias( 765 selection, 766 alias=selection.output_name or f"_col_{i}", 767 copy=False, 768 ) 769 if aliased_column: 770 selection.set("alias", exp.to_identifier(aliased_column)) 771 772 new_selections.append(selection) 773 774 if isinstance(scope.expression, exp.Select): 775 scope.expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
778def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 779 """Makes sure all identifiers that need to be quoted are quoted.""" 780 return expression.transform( 781 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 782 ) # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
785def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: 786 """ 787 Pushes down the CTE alias columns into the projection, 788 789 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 790 791 Example: 792 >>> import sqlglot 793 >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") 794 >>> pushdown_cte_alias_columns(expression).sql() 795 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' 796 797 Args: 798 expression: Expression to pushdown. 799 800 Returns: 801 The expression with the CTE aliases pushed down into the projection. 802 """ 803 for cte in expression.find_all(exp.CTE): 804 if cte.alias_column_names: 805 new_expressions = [] 806 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 807 if isinstance(projection, exp.Alias): 808 projection.set("alias", _alias) 809 else: 810 projection = alias(projection, alias=_alias) 811 new_expressions.append(projection) 812 cte.this.set("expressions", new_expressions) 813 814 return expression
Pushes down the CTE alias columns into the projection,
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") >>> pushdown_cte_alias_columns(expression).sql() 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-computed caches, populated on first use
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        # Whether to fall back to inference when the schema lacks a column
        self._infer_schema = infer_schema
        # Cache for get_source_columns, keyed by (source name, only_visible)
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table identifier if it can be found/inferred, else None.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # Fall back to inference: if exactly one source has no known
            # columns (or a star), assume the column belongs to it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the ancestor carrying the matching alias.
            # NOTE(review): assumes an aliased ancestor exists -- otherwise
            # `node` could be None at the args.get below; confirm invariant.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Physical table: ask the schema for its column names
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                # Derived table / CTE: use its projected output names
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                # Renamed from (name, alias) to avoid shadowing the `name`
                # parameter and the module-level `alias` helper.
                columns = [
                    column_alias or column
                    for (column, column_alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            # Only the source names are needed; iterating the mappings directly
            # avoids unpacking an unused `source` value
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column appearing in more than one source is ambiguous: drop it
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
824 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 825 self.scope = scope 826 self.schema = schema 827 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 828 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 829 self._all_columns: t.Optional[t.Set[str]] = None 830 self._infer_schema = infer_schema 831 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
833 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 834 """ 835 Get the table for a column name. 836 837 Args: 838 column_name: The column name to find the table for. 839 Returns: 840 The table name if it can be found/inferred. 841 """ 842 if self._unambiguous_columns is None: 843 self._unambiguous_columns = self._get_unambiguous_columns( 844 self._get_all_source_columns() 845 ) 846 847 table_name = self._unambiguous_columns.get(column_name) 848 849 if not table_name and self._infer_schema: 850 sources_without_schema = tuple( 851 source 852 for source, columns in self._get_all_source_columns().items() 853 if not columns or "*" in columns 854 ) 855 if len(sources_without_schema) == 1: 856 table_name = sources_without_schema[0] 857 858 if table_name not in self.scope.selected_sources: 859 return exp.to_identifier(table_name) 860 861 node, _ = self.scope.selected_sources.get(table_name) 862 863 if isinstance(node, exp.Query): 864 while node and node.alias != table_name: 865 node = node.parent 866 867 node_alias = node.args.get("alias") 868 if node_alias: 869 return exp.to_identifier(node_alias.this) 870 871 return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table identifier if it can be found or inferred, otherwise `None`.
all_columns: Set[str]
873 @property 874 def all_columns(self) -> t.Set[str]: 875 """All available columns of all sources in this scope""" 876 if self._all_columns is None: 877 self._all_columns = { 878 column for columns in self._get_all_source_columns().values() for column in columns 879 } 880 return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
882 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 883 """Resolve the source columns for a given source `name`.""" 884 cache_key = (name, only_visible) 885 if cache_key not in self._get_source_columns_cache: 886 if name not in self.scope.sources: 887 raise OptimizeError(f"Unknown table: {name}") 888 889 source = self.scope.sources[name] 890 891 if isinstance(source, exp.Table): 892 columns = self.schema.column_names(source, only_visible) 893 elif isinstance(source, Scope) and isinstance( 894 source.expression, (exp.Values, exp.Unnest) 895 ): 896 columns = source.expression.named_selects 897 898 # in bigquery, unnest structs are automatically scoped as tables, so you can 899 # directly select a struct field in a query. 900 # this handles the case where the unnest is statically defined. 901 if self.schema.dialect == "bigquery": 902 if source.expression.is_type(exp.DataType.Type.STRUCT): 903 for k in source.expression.type.expressions: # type: ignore 904 columns.append(k.name) 905 else: 906 columns = source.expression.named_selects 907 908 node, _ = self.scope.selected_sources.get(name) or (None, None) 909 if isinstance(node, Scope): 910 column_aliases = node.expression.alias_column_names 911 elif isinstance(node, exp.Expression): 912 column_aliases = node.alias_column_names 913 else: 914 column_aliases = [] 915 916 if column_aliases: 917 # If the source's columns are aliased, their aliases shadow the corresponding column names. 918 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 919 columns = [ 920 alias or name 921 for (name, alias) in itertools.zip_longest(columns, column_aliases) 922 ] 923 924 self._get_source_columns_cache[cache_key] = columns 925 926 return self._get_source_columns_cache[cache_key]
Resolve the source columns for a given source `name`.