sqlglot.optimizer.scope
from __future__ import annotations

import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name, mypyc_attr, seq_get
from builtins import type as Type

logger = logging.getLogger("sqlglot")

if t.TYPE_CHECKING:
    from sqlglot._typing import E
    from collections.abc import Iterator

# Expression types for which `traverse_scope` will build scopes; anything else yields no scopes.
TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)


class ScopeType(Enum):
    """Kind of a `Scope`, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


@mypyc_attr(native_class=True)
class Scope:
    """
    Selection scope.

    Attributes:
        expression: Root expression of this scope
        sources: Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources: Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources: Sources from CTES
        outer_columns: If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent: Parent scope
        scope_type: Type of this scope, relative to its parent
        subquery_scopes: List of all child scopes for subqueries
        cte_scopes: List of all child scopes for CTEs
        derived_table_scopes: List of all child scopes for derived_tables
        udtf_scopes: List of all child scopes for user defined tabular functions
        table_scopes: derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes: If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    # Filled by `_collect` in a single walk over the scope's expression.
    _collected: bool
    _raw_columns: list[exp.Column]
    _table_columns: list[exp.TableColumn]
    _stars: list[exp.Column | exp.Dot]
    _derived_tables: list[exp.Subquery]
    _udtfs: list[exp.UDTF]
    _tables: list[exp.Table]
    _ctes: list[exp.CTE]
    _subqueries: list[exp.Select | exp.SetOperation]
    _join_hints: list[exp.JoinHint]
    _semi_anti_join_tables: set[str]
    _column_index: set[int]
    # Lazily computed by the matching properties; `None` means "not computed yet".
    _selected_sources: dict[str, tuple[exp.Selectable, exp.Table | Scope]] | None
    _columns: list[exp.Column] | None
    _external_columns: list[exp.Column] | None
    _local_columns: list[exp.Column] | None
    _pivots: list[exp.Pivot] | None
    _references: list[tuple[str, exp.Selectable]] | None

    def __init__(
        self,
        expression: exp.Expr,
        sources: dict[str, exp.Table | Scope] | None = None,
        outer_columns: list[str] | None = None,
        parent: Scope | None = None,
        scope_type: ScopeType = ScopeType.ROOT,
        lateral_sources: dict[str, exp.Table | Scope] | None = None,
        cte_sources: dict[str, exp.Table | Scope] | None = None,
        can_be_correlated: bool | None = None,
    ) -> None:
        """
        Create a scope for `expression`.

        Note that a non-empty `sources` dict is stored as-is (not copied) and is then
        updated in place with `lateral_sources` and `cte_sources`; `branch` passes copies.
        """
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # Lateral and CTE sources are also visible through `sources`; CTE entries are
        # applied last, so they win on a name clash.
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes: list[Scope] = []
        self.derived_table_scopes: list[Scope] = []
        self.table_scopes: list[Scope] = []
        self.cte_scopes: list[Scope] = []
        self.union_scopes: list[Scope] = []
        self.udtf_scopes: list[Scope] = []
        self.can_be_correlated = can_be_correlated
        self.clear_cache()

    def clear_cache(self) -> None:
        """
        Drop everything derived from the expression tree.

        The collected node lists are emptied and the lazily computed properties are reset
        to `None`, so the next access walks the expression again. Child scope lists and
        `sources` are not touched.
        """
        self._collected = False
        self._raw_columns = []
        self._table_columns = []
        self._stars = []
        self._derived_tables = []
        self._udtfs = []
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._join_hints = []
        self._semi_anti_join_tables = set()
        self._column_index = set()
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._local_columns = None
        self._pivots = None
        self._references = None

    def branch(
        self,
        expression: exp.Expr,
        scope_type: ScopeType,
        sources: dict[str, exp.Table | Scope] | None = None,
        cte_sources: dict[str, exp.Table | Scope] | None = None,
        lateral_sources: dict[str, exp.Table | Scope] | None = None,
        outer_columns: list[str] | None = None,
    ) -> Scope:
        """
        Branch from the current scope to a new, inner scope

        The new scope wraps `expression.unnest()`, has this scope as its parent, and sees
        this scope's CTE sources merged with `cte_sources` (the latter taking precedence).
        `sources` and `lateral_sources` are copied. The child can be correlated if this
        scope can be, or if `scope_type` is SUBQUERY or UDTF.
        """
        return Scope(
            expression=expression.unnest(),
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            can_be_correlated=self.can_be_correlated
            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
            outer_columns=outer_columns,
        )

    def _collect(self) -> None:
        """
        Walk this scope once (see `walk_in_scope`) and bucket the visited nodes by kind.

        The scope's own root expression is skipped. Sets `_collected` when done.
        """
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._table_columns = []
        self._stars = []
        self._join_hints = []
        self._semi_anti_join_tables = set()
        self._column_index = set()

        for node in self.walk():
            if node is self.expression:
                continue

            if isinstance(node, exp.Dot) and node.is_star:
                self._stars.append(node)
            elif type(node) is exp.Column:
                # Exact type match: instances of Column subclasses do not take this branch.
                self._column_index.add(id(node))

                if isinstance(node.this, exp.Star):
                    self._stars.append(node)
                else:
                    self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                parent = node.parent
                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
                    # Remembered by name so `selected_sources` can leave these tables out.
                    self._semi_anti_join_tables.add(node.alias_or_name)

                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif _is_derived_table(node) and _is_from_or_join(node):
                self._derived_tables.append(t.cast(exp.Subquery, node))
            elif isinstance(node, exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node):
                self._subqueries.append(node)
            elif isinstance(node, exp.TableColumn):
                self._table_columns.append(node)

        self._collected = True

    def _ensure_collected(self) -> None:
        """Run `_collect` unless it already ran since the last `clear_cache`."""
        if not self._collected:
            self._collect()

    def walk(self, prune: t.Callable[[exp.Expr], bool] | None = None) -> Iterator[exp.Expr]:
        """Walk the nodes of this scope's expression; see `walk_in_scope`."""
        return walk_in_scope(self.expression, prune=prune)

    def find(self, *expression_types: Type[E]) -> E | None:
        """Return the first node of this scope matching any given type; see `find_in_scope`."""
        return find_in_scope(self.expression, *expression_types)

    def find_all(self, *expression_types: Type[E]) -> Iterator[E]:
        """Yield the nodes of this scope matching any given type; see `find_all_in_scope`."""
        return find_all_in_scope(self.expression, *expression_types)

    def replace(self, old: exp.Expr, new: exp.Expr) -> None:
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expr.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expr): old node
            new (exp.Expr): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self) -> list[exp.Table]:
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self) -> list[exp.CTE]:
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self) -> list[exp.Subquery]:
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self) -> list[exp.UDTF]:
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self) -> list[exp.Select | exp.SetOperation]:
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.SetOperation]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def stars(self) -> list[exp.Column | exp.Dot]:
        """
        List of star expressions (columns or dots) in this scope.
        """
        self._ensure_collected()
        return self._stars

    @property
    def column_index(self) -> set[int]:
        """
        Set of column object IDs that belong to this scope's expression.
        """
        self._ensure_collected()
        return self._column_index

    @property
    def columns(self) -> list[exp.Column]:
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            # External columns of child scopes are candidates too: all subquery and UDTF
            # scopes, and only those derived table scopes that can be correlated.
            external_columns = [
                column
                for scope in itertools.chain(
                    self.subquery_scopes,
                    self.udtf_scopes,
                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
                )
                for column in scope.external_columns
            ]

            expr = self.expression
            named_selects = set(expr.named_selects) if isinstance(expr, exp.Query) else set()

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select,
                    exp.Qualify,
                    exp.Order,
                    exp.Having,
                    exp.Hint,
                    exp.Table,
                    exp.Star,
                    exp.Distinct,
                )
                # A column is kept when any of these holds:
                #   - it has none of the ancestors listed above
                #   - it is qualified with a table
                #   - the ancestor found is a Select
                #   - the ancestor found is a Table whose `this` is not a Func
                #   - the ancestor found is an Order/Distinct that belongs to a
                #     Window/WithinGroup, or is not directly under a Select, or the
                #     column's name is not one of this query's named selects
                #   - the ancestor found is a Star and the column is not its `except_` arg
                # NOTE(review): the remaining (dropped) cases look like unqualified names
                # that refer to select aliases rather than source columns -- confirm.
                if (
                    not ancestor
                    or column.text("table")
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, (exp.Order, exp.Distinct))
                        and (
                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
                            or not isinstance(ancestor.parent, exp.Select)
                            or column.name not in named_selects
                        )
                    )
                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def table_columns(self) -> list[exp.TableColumn]:
        """List of `exp.TableColumn` nodes collected in this scope."""
        self._ensure_collected()
        return self._table_columns

    @property
    def selected_sources(self) -> dict[str, tuple[exp.Selectable, exp.Table | Scope]]:
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two references in this scope share the same name.
        """
        if self._selected_sources is None:
            result: dict[str, tuple[exp.Selectable, exp.Table | Scope]] = {}

            # Evaluating `self.references` triggers collection, which also fills
            # `_semi_anti_join_tables` before it is read below.
            for name, node in self.references:
                if name in self._semi_anti_join_tables:
                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
                    # selected source
                    continue

                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                # References without a matching entry in `sources` are silently skipped.
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> list[tuple[str, exp.Selectable]]:
        """
        List of `(name, node)` pairs for everything this scope references as a source.

        Tables come first (named by `alias_or_name`), followed by derived tables and UDTFs
        (named by `_get_source_alias`). For the latter, the node is the expression itself
        when it has pivots, otherwise its unnested form.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for _expr in itertools.chain(self.derived_tables, self.udtfs):
                # TODO (mypyc): rebind to exp.Expr to avoid DerivedTable trait vtable dispatch
                expression: exp.Expr = _expr
                self._references.append(
                    (
                        _get_source_alias(expression),
                        (
                            expression if expression.args.get("pivots") else expression.unnest()
                        ).assert_is(exp.Selectable),
                    )
                )

        return self._references

    @property
    def external_columns(self) -> list[exp.Column]:
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.SetOperation):
                # A set operation has no columns of its own here; combine both branches.
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                local_source_names = {name for name, _ in self.references}
                # Unqualified columns have an empty table text, so they count as external
                # unless a local reference is itself named "".
                self._external_columns = [
                    c
                    for c in self.columns
                    if c.text("table") not in local_source_names
                    and c.text("table") not in self.semi_or_anti_join_tables
                ]

        return self._external_columns

    @property
    def local_columns(self) -> list[exp.Column]:
        """
        Columns in this scope that are not external.

        Returns:
            list[exp.Column]: Column instances that reference sources in the current scope.
        """
        if self._local_columns is None:
            external_columns = set(self.external_columns)
            self._local_columns = [c for c in self.columns if c not in external_columns]

        return self._local_columns

    @property
    def unqualified_columns(self) -> list[exp.Column]:
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.text("table")]

    @property
    def join_hints(self) -> list[exp.JoinHint]:
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self) -> list[exp.Pivot]:
        """All entries of the `pivots` arg of every node in `references`, in reference order."""
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    @property
    def semi_or_anti_join_tables(self) -> set[str]:
        """Names (`alias_or_name`) of tables that are the direct child of a SEMI/ANTI join."""
        self._ensure_collected()
        return self._semi_anti_join_tables

    def source_columns(self, source_name: str) -> list[exp.Column]:
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.text("table") == source_name]

    @property
    def is_subquery(self) -> bool:
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self) -> bool:
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self) -> bool:
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self) -> bool:
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self) -> bool:
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self) -> bool:
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self) -> bool:
        """Determine if this scope is a correlated subquery"""
        return bool(self.can_be_correlated and self.external_columns)

    def rename_source(self, old_name: str | None, new_name: str) -> None:
        """
        Rename a source in this scope

        A `None` old name is treated as "" (the key used for unaliased sources). Nothing
        happens if `old_name` is not a known source.
        """
        # NOTE(review): unlike `add_source`/`remove_source`, this does not call
        # `clear_cache()`, so an already computed `selected_sources` keeps the old
        # name -- confirm this is intentional.
        old_name = old_name or ""
        if old_name in self.sources:
            self.sources[new_name] = self.sources.pop(old_name)

    def add_source(self, name: str, source: exp.Table | Scope) -> None:
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name: str) -> None:
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self) -> str:
        return f"Scope<{self.expression.sql()}>"

    def traverse(self) -> Iterator[Scope]:
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        # Iterative traversal: record scopes parent-first, then emit the record reversed
        # so that every scope comes after all of its descendants.
        stack: list[Scope] = [self]
        result: list[Scope] = []
        while stack:
            scope = stack.pop()
            result.append(scope)
            stack.extend(
                itertools.chain(
                    scope.cte_scopes,
                    scope.union_scopes,
                    scope.table_scopes,
                    scope.subquery_scopes,
                )
            )

        yield from reversed(result)

    def ref_count(self) -> dict[int, int]:
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count: dict[int, int] = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

            # `selected_sources` above has already triggered collection for `scope`.
            for name in scope._semi_anti_join_tables:
                # semi/anti join sources are not actually selected but we still need to
                # increment their ref count to avoid them being optimized away
                if name in scope.sources:
                    scope_ref_count[id(scope.sources[name])] += 1

        return scope_ref_count


def traverse_scope(expression: exp.Expr) -> list[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expr to traverse

    Returns:
        A list of the created scope instances
    """
    if isinstance(expression, TRAVERSABLES):
        return list(_traverse_scope(Scope(expression)))
    return []


def build_scope(expression: exp.Expr) -> Scope | None:
    """
    Build a scope tree.

    Args:
        expression: Expr to build the scope tree for.

    Returns:
        The root scope
    """
    # The last scope produced by the traversal is taken as the root; `None` when the
    # expression produced no scopes at all.
    return seq_get(traverse_scope(expression), -1)


def _traverse_scope(scope: Scope) -> Iterator[Scope]:
    """
    Build and yield all scopes nested under `scope`, dispatching on its expression type.

    For Select, Subquery, Table and UDTF expressions, `scope` itself is yielded last.
    Set operations, DDL and DML return early: for set operations `_traverse_union` yields
    the scope, while for DDL/DML only the scopes of the contained queries are yielded.
    Unsupported expression types are logged and produce nothing.
    """
    expression = scope.expression

    if isinstance(expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(expression, exp.SetOperation):
        yield from _traverse_ctes(scope)
        yield from _traverse_union(scope)
        return
    elif isinstance(expression, exp.Subquery):
        if scope.is_root:
            yield from _traverse_select(scope)
        else:
            yield from _traverse_subqueries(scope)
    elif isinstance(expression, exp.Table):
        yield from _traverse_tables(scope)
    elif isinstance(expression, exp.UDTF):
        yield from _traverse_udtfs(scope)
    elif isinstance(expression, exp.DDL):
        # TODO (mypyc): change to ddl_expression = expression.expression
        ddl_expression = expression.args.get("expression")
        if isinstance(ddl_expression, exp.Query):
            yield from _traverse_ctes(scope)
            # The wrapped query gets a fresh root scope that can see the statement's CTEs.
            yield from _traverse_scope(Scope(ddl_expression, cte_sources=scope.cte_sources))
        return
    elif isinstance(expression, exp.DML):
        yield from _traverse_ctes(scope)
        for query in find_all_in_scope(expression, exp.Query):
            # This check ensures we don't yield the CTE/nested queries twice
            if not isinstance(query.parent, (exp.CTE, exp.Subquery)):
                yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
        return
    else:
        logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
        return

    yield scope


def _traverse_select(scope: Scope) -> Iterator[Scope]:
    """Yield the child scopes of a select-like scope: CTEs first, then tables, then subqueries."""
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope: Scope) -> Iterator[Scope]:
    """
    Yield the scopes under a set operation, followed by the set operation scopes themselves.

    Nested set operations are handled with explicit stacks instead of recursion. Each set
    operation scope gets `union_scopes = [left, right]` assigned right before it is yielded.
    """
    prev_scope: Scope | None = None
    union_scope_stack: list[Scope] = [scope]

    set_op = scope.expression
    assert isinstance(set_op, exp.SetOperation)
    # Right is pushed first so that the left operand is popped (and traversed) first.
    expression_stack: list[exp.Expr] = [set_op.right, set_op.left]

    while expression_stack:
        expression = expression_stack.pop()
        union_scope = union_scope_stack[-1]

        new_scope = union_scope.branch(
            expression,
            outer_columns=union_scope.outer_columns,
            scope_type=ScopeType.UNION,
        )

        if isinstance(expression, exp.SetOperation):
            yield from _traverse_ctes(new_scope)

            # Descend: the nested set operation becomes the current union scope and its
            # operands are processed before anything else on the stack.
            union_scope_stack.append(new_scope)
            expression_stack.extend([expression.right, expression.left])
            continue

        # Note: the loop variable rebinds `scope`; after the loop it holds the last
        # (i.e. top-level) scope yielded for this operand.
        for scope in _traverse_scope(new_scope):
            yield scope

        if prev_scope:
            # Both operands of the current set operation are done: finish it.
            union_scope_stack.pop()
            union_scope.union_scopes = [prev_scope, scope]
            prev_scope = union_scope

            yield union_scope
        else:
            prev_scope = scope


def _traverse_ctes(scope: Scope) -> Iterator[Scope]:
    """
    Yield the scopes of every CTE collected in `scope` and register them as sources.

    The top-level scope of each CTE is appended to `scope.cte_scopes`. Once all CTEs are
    processed, they are added to both `scope.sources` and `scope.cte_sources`; each CTE
    can see the CTEs defined before it through the `cte_sources` passed to `branch`.
    """
    sources: dict[str, exp.Table | Scope] = {}

    for cte in scope.ctes:
        cte_name = cte.alias

        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
        # thus the recursive scope is the first section of the union.
        with_ = scope.expression.args.get("with_")
        if with_ and with_.recursive:
            union = cte.this

            if isinstance(union, exp.SetOperation):
                # Provisional entry so the CTE's own body can resolve its name; it is
                # replaced by the real scope below.
                sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope: Scope | None = None

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                cte_sources=sources,
                outer_columns=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

        # append the final child_scope yielded
        if child_scope:
            sources[cte_name] = child_scope
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
    scope.cte_sources.update(sources)


def _is_derived_table(expression: exp.Expr) -> bool:
    """
    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule.
    """
    return isinstance(expression, exp.Subquery) and bool(
        expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
    )


def _is_from_or_join(expression: exp.Expr) -> bool:
    """
    Determine if `expression` is the FROM or JOIN clause of a SELECT statement.
    """
    parent = expression.parent

    # Subqueries can be arbitrarily nested
    while type(parent) is exp.Subquery:
        parent = parent.parent

    # Exact type checks: subclasses of From/Join/Subquery do not match here.
    return type(parent) in (exp.From, exp.Join)


def _traverse_tables(scope: Scope) -> Iterator[Scope]:
    """
    Resolve the FROM/JOIN/LATERAL sources of `scope`, yielding the scopes of any derived
    tables and UDTFs found along the way.

    Child scopes are recorded in `scope.derived_table_scopes` or `scope.udtf_scopes`, and
    always in `scope.table_scopes`. The resolved sources are merged into `scope.sources`
    at the very end.
    """
    sources: dict[str, exp.Table | Scope] = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions: list[exp.Expr] = []
    from_ = scope.expression.args.get("from_")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # `expressions` is extended while it is being iterated; items appended inside the loop
    # are picked up by later iterations of this same loop.
    for expression in expressions:
        if isinstance(expression, exp.Final):
            expression = expression.this
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                # Name already taken in this scope: register the table under a fresh name.
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        # TODO (mypyc): rebind to exp.Expr to avoid DerivedTable/UDTF trait vtable dispatch
        node: exp.Expr = expression

        if isinstance(expression, exp.UDTF):
            # UDTFs get the sources resolved so far as their lateral sources.
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in node.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(node.this)
            expressions.extend(join.this for join in node.args.get("joins") or [])
            continue

        child_scope: Scope | None = None

        for child_scope in _traverse_scope(
            scope.branch(
                node,
                lateral_sources=lateral_sources,
                outer_columns=node.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            sources[_get_source_alias(node)] = child_scope

        # append the final child_scope yielded
        if child_scope:
            scopes.append(child_scope)
            scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope: Scope) -> Iterator[Scope]:
    """
    Yield the scopes of every subquery collected in `scope`.

    For each subquery, the last scope yielded (its top-level scope) is appended to
    `scope.subquery_scopes`.
    """
    for subquery in scope.subqueries:
        top: Scope | None = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        if top is not None:
            scope.subquery_scopes.append(top)


def _traverse_udtfs(scope: Scope) -> Iterator[Scope]:
    """
    Yield the scopes of subqueries used as arguments of a UDTF scope.

    Only `exp.Unnest` (its `expressions`) and `exp.Lateral` (its `this`) are inspected, and
    only arguments that are `exp.Subquery` nodes produce scopes. Each such subquery's
    top-level scope is appended to `scope.subquery_scopes` and registered in
    `scope.sources` under the subquery's source alias.
    """
    if isinstance(scope.expression, exp.Unnest):
        udtf_expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        udtf_expressions = [scope.expression.this]
    else:
        udtf_expressions = []

    sources: dict[str, exp.Table | Scope] = {}
    for expression in udtf_expressions:
        if isinstance(expression, exp.Subquery):
            top: Scope | None = None
            for child_scope in _traverse_scope(
                scope.branch(
                    expression,
                    scope_type=ScopeType.SUBQUERY,
                    outer_columns=expression.alias_column_names,
                )
            ):
                yield child_scope
                top = child_scope
                # Overwritten on every iteration, so the last scope yielded wins.
                sources[_get_source_alias(expression)] = child_scope

            if top is not None:
                scope.subquery_scopes.append(top)

    scope.sources.update(sources)


def walk_in_scope(
    expression: exp.Expr,
    prune: t.Callable[[exp.Expr], bool] | None = None,
) -> Iterator[exp.Expr]:
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes. This does a custom DFS traversal rather than using
    expression.walk() because the nested generators aren't optimized in mypyc.

    Args:
        expression: the root of the walk. It is always descended into, even if it is
            itself a node that would otherwise start a child scope.
        prune: callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expr: each node in scope
    """
    stack: list[exp.Expr] = [expression]

    while stack:
        node = stack.pop()

        # The node is yielded before the checks below, so nodes that start a child scope
        # (and pruned nodes) are still reported; only their children are skipped.
        yield node

        if node is not expression and (
            isinstance(node, exp.CTE)
            or (isinstance(node.parent, (exp.From, exp.Join)) and _is_derived_table(node))
            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # These args are still walked as part of the current scope, even though
                # the rest of the node is not.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg)
            continue

        if prune and prune(node):
            continue

        # Push children in reverse so they are popped (visited) in their original order.
        for vs in reversed(node.args.values()):
            if isinstance(vs, list):
                for v in reversed(vs):
                    if isinstance(v, exp.Expr):
                        stack.append(v)
            elif isinstance(vs, exp.Expr):
                stack.append(vs)


def find_all_in_scope(
    expression: exp.Expr,
    *expression_types: Type[E],
) -> Iterator[E]:
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression: the expression to search.
        expression_types: the expression type(s) to match.

    Yields:
        The matching nodes.
    """
    for node in walk_in_scope(expression):
        if isinstance(node, expression_types):
            yield node


def find_in_scope(
    expression: exp.Expr,
    *expression_types: Type[E],
) -> E | None:
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression: the expression to search.
        expression_types: the expression type(s) to match.

    Returns:
        The node which matches the criteria or None if no node matching
        the criteria was found.
    """
    return next(find_all_in_scope(expression, *expression_types), None)


def _get_source_alias(expression: exp.Expr) -> str:
    """
    Return the name under which `expression` is registered as a source.

    This is the expression's alias. When there is no alias name but the `alias` arg is a
    `TableAlias` with exactly one column, that column's name is used instead.
    """
    alias_arg = expression.args.get("alias")
    alias_name = expression.alias

    if not alias_name and isinstance(alias_arg, exp.TableAlias) and len(alias_arg.columns) == 1:
        alias_name = alias_arg.columns[0].name

    return alias_name
class ScopeType(Enum):
    """
    Kind of a scope, relative to its parent scope.

    The members carry the consecutive integer values 1 through 6, in declaration order.
    """

    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6
An enumeration.
@mypyc_attr(native_class=True)
class Scope:
    """
    Selection scope.

    Attributes:
        expression: Root expression of this scope
        sources: Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources: Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources: Sources from CTEs
        outer_columns: If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent: Parent scope
        scope_type: Type of this scope, relative to its parent
        subquery_scopes: List of all child scopes for subqueries
        cte_scopes: List of all child scopes for CTEs
        derived_table_scopes: List of all child scopes for derived_tables
        udtf_scopes: List of all child scopes for user defined tabular functions
        table_scopes: derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes: If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    # Set to True once `_collect` has walked the expression.
    _collected: bool
    # Node buckets filled by `_collect`.
    _raw_columns: list[exp.Column]
    _table_columns: list[exp.TableColumn]
    _stars: list[exp.Column | exp.Dot]
    _derived_tables: list[exp.Subquery]
    _udtfs: list[exp.UDTF]
    _tables: list[exp.Table]
    _ctes: list[exp.CTE]
    _subqueries: list[exp.Select | exp.SetOperation]
    _join_hints: list[exp.JoinHint]
    _semi_anti_join_tables: set[str]
    _column_index: set[int]
    # Values memoized by the matching properties; None means "not computed yet".
    _selected_sources: dict[str, tuple[exp.Selectable, exp.Table | Scope]] | None
    _columns: list[exp.Column] | None
    _external_columns: list[exp.Column] | None
    _local_columns: list[exp.Column] | None
    _pivots: list[exp.Pivot] | None
    _references: list[tuple[str, exp.Selectable]] | None

    def __init__(
        self,
        expression: exp.Expr,
        sources: dict[str, exp.Table | Scope] | None = None,
        outer_columns: list[str] | None = None,
        parent: Scope | None = None,
        scope_type: ScopeType = ScopeType.ROOT,
        lateral_sources: dict[str, exp.Table | Scope] | None = None,
        cte_sources: dict[str, exp.Table | Scope] | None = None,
        can_be_correlated: bool | None = None,
    ) -> None:
        """
        Create a scope around `expression`.

        A non-empty `sources` dict is stored as-is (not copied) and is then
        updated in place with the lateral and CTE sources.
        """
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # Lateral and CTE sources are visible as regular sources too; CTE
        # entries are applied last, so they win on a name clash.
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes: list[Scope] = []
        self.derived_table_scopes: list[Scope] = []
        self.table_scopes: list[Scope] = []
        self.cte_scopes: list[Scope] = []
        self.union_scopes: list[Scope] = []
        self.udtf_scopes: list[Scope] = []
        self.can_be_correlated = can_be_correlated
        self.clear_cache()

    def clear_cache(self) -> None:
        """Reset all collected and memoized state so it is rebuilt on next access."""
        self._collected = False
        self._raw_columns = []
        self._table_columns = []
        self._stars = []
        self._derived_tables = []
        self._udtfs = []
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._join_hints = []
        self._semi_anti_join_tables = set()
        self._column_index = set()
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._local_columns = None
        self._pivots = None
        self._references = None

    def branch(
        self,
        expression: exp.Expr,
        scope_type: ScopeType,
        sources: dict[str, exp.Table | Scope] | None = None,
        cte_sources: dict[str, exp.Table | Scope] | None = None,
        lateral_sources: dict[str, exp.Table | Scope] | None = None,
        outer_columns: list[str] | None = None,
    ) -> Scope:
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            # Copied so the child's constructor does not mutate the caller's dict.
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            # The child sees this scope's CTE sources; explicitly passed ones override them.
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            can_be_correlated=self.can_be_correlated
            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
            outer_columns=outer_columns,
        )

    def _collect(self) -> None:
        """Walk this scope's expression once and sort its nodes into the `_*` buckets."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._table_columns = []
        self._stars = []
        self._join_hints = []
        self._semi_anti_join_tables = set()
        self._column_index = set()

        for node in self.walk():
            if node is self.expression:
                continue

            if isinstance(node, exp.Dot) and node.is_star:
                self._stars.append(node)
            elif type(node) is exp.Column:
                # Exact type match: subclasses of exp.Column are not treated as columns here.
                self._column_index.add(id(node))

                if isinstance(node.this, exp.Star):
                    self._stars.append(node)
                else:
                    self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                parent = node.parent
                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
                    self._semi_anti_join_tables.add(node.alias_or_name)

                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif _is_derived_table(node) and _is_from_or_join(node):
                self._derived_tables.append(t.cast(exp.Subquery, node))
            elif isinstance(node, exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node):
                self._subqueries.append(node)
            elif isinstance(node, exp.TableColumn):
                self._table_columns.append(node)

        self._collected = True

    def _ensure_collected(self) -> None:
        """Run `_collect` unless it has already run since the last `clear_cache`."""
        if not self._collected:
            self._collect()

    def walk(self, prune: t.Callable[[exp.Expr], bool] | None = None) -> Iterator[exp.Expr]:
        """Delegate to `walk_in_scope` on this scope's expression, forwarding `prune`."""
        return walk_in_scope(self.expression, prune=prune)

    def find(self, *expression_types: Type[E]) -> E | None:
        """Delegate to `find_in_scope` on this scope's expression."""
        return find_in_scope(self.expression, *expression_types)

    def find_all(self, *expression_types: Type[E]) -> Iterator[E]:
        """Delegate to `find_all_in_scope` on this scope's expression."""
        return find_all_in_scope(self.expression, *expression_types)

    def replace(self, old: exp.Expr, new: exp.Expr) -> None:
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expr.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expr): old node
            new (exp.Expr): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self) -> list[exp.Table]:
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self) -> list[exp.CTE]:
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self) -> list[exp.Subquery]:
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self) -> list[exp.UDTF]:
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self) -> list[exp.Select | exp.SetOperation]:
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.SetOperation]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def stars(self) -> list[exp.Column | exp.Dot]:
        """
        List of star expressions (columns or dots) in this scope.
        """
        self._ensure_collected()
        return self._stars

    @property
    def column_index(self) -> set[int]:
        """
        Set of column object IDs that belong to this scope's expression.
        """
        self._ensure_collected()
        return self._column_index

    @property
    def columns(self) -> list[exp.Column]:
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            # Columns that child scopes report as external; derived tables only
            # contribute when they are flagged as possibly correlated.
            external_columns = [
                column
                for scope in itertools.chain(
                    self.subquery_scopes,
                    self.udtf_scopes,
                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
                )
                for column in scope.external_columns
            ]

            expr = self.expression
            named_selects = set(expr.named_selects) if isinstance(expr, exp.Query) else set()

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select,
                    exp.Qualify,
                    exp.Order,
                    exp.Having,
                    exp.Hint,
                    exp.Table,
                    exp.Star,
                    exp.Distinct,
                )
                # A column is dropped only when it is unqualified and the ancestor found
                # above is one of the cases not accepted below (e.g. an ORDER BY / DISTINCT
                # directly under a Select whose name matches a named select).
                if (
                    not ancestor
                    or column.text("table")
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, (exp.Order, exp.Distinct))
                        and (
                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
                            or not isinstance(ancestor.parent, exp.Select)
                            or column.name not in named_selects
                        )
                    )
                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def table_columns(self) -> list[exp.TableColumn]:
        """List of `exp.TableColumn` nodes collected in this scope."""
        self._ensure_collected()
        return self._table_columns

    @property
    def selected_sources(self) -> dict[str, tuple[exp.Selectable, exp.Table | Scope]]:
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if the same name is selected more than once.
        """
        if self._selected_sources is None:
            result: dict[str, tuple[exp.Selectable, exp.Table | Scope]] = {}

            for name, node in self.references:
                if name in self._semi_anti_join_tables:
                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
                    # selected source
                    continue

                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> list[tuple[str, exp.Selectable]]:
        """
        (name, node) pairs for every table, derived table and UDTF in this scope.

        Tables are listed first, under their alias or name; derived tables and
        UDTFs follow, named via `_get_source_alias`.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for _expr in itertools.chain(self.derived_tables, self.udtfs):
                # TODO (mypyc): rebind to exp.Expr to avoid DerivedTable trait vtable dispatch
                expression: exp.Expr = _expr
                self._references.append(
                    (
                        _get_source_alias(expression),
                        (
                            # Keep the wrapper node when it carries pivots; otherwise unnest it.
                            expression if expression.args.get("pivots") else expression.unnest()
                        ).assert_is(exp.Selectable),
                    )
                )

        return self._references

    @property
    def external_columns(self) -> list[exp.Column]:
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.SetOperation):
                # A set operation's external columns are those of its two operand scopes.
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                local_source_names = {name for name, _ in self.references}
                self._external_columns = [
                    c
                    for c in self.columns
                    if c.text("table") not in local_source_names
                    and c.text("table") not in self.semi_or_anti_join_tables
                ]

        return self._external_columns

    @property
    def local_columns(self) -> list[exp.Column]:
        """
        Columns in this scope that are not external.

        Returns:
            list[exp.Column]: Column instances that reference sources in the current scope.
        """
        if self._local_columns is None:
            external_columns = set(self.external_columns)
            self._local_columns = [c for c in self.columns if c not in external_columns]

        return self._local_columns

    @property
    def unqualified_columns(self) -> list[exp.Column]:
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.text("table")]

    @property
    def join_hints(self) -> list[exp.JoinHint]:
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self) -> list[exp.Pivot]:
        """Entries of the "pivots" arg of every referenced node, flattened into one list."""
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    @property
    def semi_or_anti_join_tables(self) -> set[str]:
        """Alias-or-name of every table whose parent join is a SEMI or ANTI join."""
        self._ensure_collected()
        return self._semi_anti_join_tables

    def source_columns(self, source_name: str) -> list[exp.Column]:
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.text("table") == source_name]

    @property
    def is_subquery(self) -> bool:
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self) -> bool:
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self) -> bool:
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self) -> bool:
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self) -> bool:
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self) -> bool:
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self) -> bool:
        """Determine if this scope is a correlated subquery"""
        return bool(self.can_be_correlated and self.external_columns)

    def rename_source(self, old_name: str | None, new_name: str) -> None:
        """Rename a source in this scope"""
        # NOTE(review): unlike add_source/remove_source, this does not call
        # clear_cache(), so memoized values keyed by name keep the old name.
        old_name = old_name or ""
        if old_name in self.sources:
            self.sources[new_name] = self.sources.pop(old_name)

    def add_source(self, name: str, source: exp.Table | Scope) -> None:
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name: str) -> None:
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self) -> str:
        return f"Scope<{self.expression.sql()}>"

    def traverse(self) -> Iterator[Scope]:
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        stack: list[Scope] = [self]
        result: list[Scope] = []
        while stack:
            scope = stack.pop()
            result.append(scope)
            stack.extend(
                itertools.chain(
                    scope.cte_scopes,
                    scope.union_scopes,
                    scope.table_scopes,
                    scope.subquery_scopes,
                )
            )

        # `result` holds every scope before its descendants; reversing it
        # yields children before their parents.
        yield from reversed(result)

    def ref_count(self) -> dict[int, int]:
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count: dict[int, int] = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

            for name in scope._semi_anti_join_tables:
                # semi/anti join sources are not actually selected but we still need to
                # increment their ref count to avoid them being optimized away
                if name in scope.sources:
                    scope_ref_count[id(scope.sources[name])] += 1

        return scope_ref_count
Selection scope.
Attributes:
- expression: Root expression of this scope
- sources: Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources: Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;` the LATERAL VIEW EXPLODE gets x as a source.
- cte_sources: Sources from CTEs
- outer_columns: If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_columns`
- parent: Parent scope
- scope_type: Type of this scope, relative to its parent
- subquery_scopes: List of all child scopes for subqueries
- cte_scopes: List of all child scopes for CTEs
- derived_table_scopes: List of all child scopes for derived_tables
- udtf_scopes: List of all child scopes for user defined tabular functions
- table_scopes: derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes: If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression: exp.Expr,
    sources: dict[str, exp.Table | Scope] | None = None,
    outer_columns: list[str] | None = None,
    parent: Scope | None = None,
    scope_type: ScopeType = ScopeType.ROOT,
    lateral_sources: dict[str, exp.Table | Scope] | None = None,
    cte_sources: dict[str, exp.Table | Scope] | None = None,
    can_be_correlated: bool | None = None,
) -> None:
    """
    Create a scope around `expression`.

    A non-empty `sources` dict is stored as-is (not copied) and is then
    updated in place with the lateral and CTE sources.
    """
    self.expression = expression

    # Falsy arguments (None or empty) are replaced by fresh containers.
    self.sources = sources if sources else {}
    self.lateral_sources = lateral_sources if lateral_sources else {}
    self.cte_sources = cte_sources if cte_sources else {}

    # Lateral sources are merged first and CTE sources last, so a CTE wins
    # over a lateral (and over a regular source) of the same name.
    for extra_sources in (self.lateral_sources, self.cte_sources):
        self.sources.update(extra_sources)

    self.outer_columns = outer_columns if outer_columns else []
    self.parent = parent
    self.scope_type = scope_type

    # Child scope registries, all initially empty.
    self.subquery_scopes: list[Scope] = []
    self.derived_table_scopes: list[Scope] = []
    self.table_scopes: list[Scope] = []
    self.cte_scopes: list[Scope] = []
    self.union_scopes: list[Scope] = []
    self.udtf_scopes: list[Scope] = []

    self.can_be_correlated = can_be_correlated
    self.clear_cache()
def clear_cache(self) -> None:
    """Reset all collected and memoized state so it is rebuilt on next access."""
    # Collection has to run again before any node bucket is trusted.
    self._collected = False

    # Node buckets; each one gets its own fresh container.
    self._tables = []
    self._ctes = []
    self._subqueries = []
    self._derived_tables = []
    self._udtfs = []
    self._raw_columns = []
    self._table_columns = []
    self._stars = []
    self._join_hints = []
    self._semi_anti_join_tables = set()
    self._column_index = set()

    # Memoized property values; None marks "not computed yet".
    self._references = None
    self._selected_sources = None
    self._columns = None
    self._external_columns = None
    self._local_columns = None
    self._pivots = None
def branch(
    self,
    expression: exp.Expr,
    scope_type: ScopeType,
    sources: dict[str, exp.Table | Scope] | None = None,
    cte_sources: dict[str, exp.Table | Scope] | None = None,
    lateral_sources: dict[str, exp.Table | Scope] | None = None,
    outer_columns: list[str] | None = None,
) -> Scope:
    """
    Create a child scope of this scope for `expression`.

    The child wraps the unnested expression, inherits this scope's CTE
    sources (explicit `cte_sources` override them) and receives copies of
    the `sources` / `lateral_sources` dicts it is given.
    """
    # Start from the inherited CTE sources, then layer the explicit ones on top.
    merged_cte_sources = dict(self.cte_sources)
    if cte_sources:
        merged_cte_sources.update(cte_sources)

    # The child may be correlated if this scope may be, or if the child is a
    # subquery or UDTF scope.
    correlated = self.can_be_correlated or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF)

    child_sources = sources.copy() if sources else None
    child_laterals = lateral_sources.copy() if lateral_sources else None

    return Scope(
        expression=expression.unnest(),
        sources=child_sources,
        parent=self,
        scope_type=scope_type,
        cte_sources=merged_cte_sources,
        lateral_sources=child_laterals,
        can_be_correlated=correlated,
        outer_columns=outer_columns,
    )
Branch from the current scope to a new, inner scope
def replace(self, old: exp.Expr, new: exp.Expr) -> None:
    """
    Replace `old` with `new`.

    This can be used instead of `exp.Expr.replace` to ensure the `Scope` is kept up-to-date.

    Args:
        old (exp.Expr): old node
        new (exp.Expr): new node
    """
    old.replace(new)
    # The tree changed, so everything collected or memoized so far may be stale.
    self.clear_cache()
Replace old with new.
This can be used instead of exp.Expr.replace to ensure the Scope is kept up-to-date.
Arguments:
- old (exp.Expr): old node
- new (exp.Expr): new node
@property
def tables(self) -> list[exp.Table]:
    """
    List of tables in this scope.

    Collection runs first if it has not run yet.

    Returns:
        list[exp.Table]: tables
    """
    self._ensure_collected()
    return self._tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self) -> list[exp.CTE]:
    """
    List of CTEs in this scope.

    Collection runs first if it has not run yet.

    Returns:
        list[exp.CTE]: ctes
    """
    self._ensure_collected()
    return self._ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self) -> list[exp.Subquery]:
    """
    List of derived tables in this scope.

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    Collection runs first if it has not run yet.

    Returns:
        list[exp.Subquery]: derived tables
    """
    self._ensure_collected()
    return self._derived_tables
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self) -> list[exp.UDTF]:
    """
    List of "User Defined Tabular Functions" in this scope.

    Collection runs first if it has not run yet.

    Returns:
        list[exp.UDTF]: UDTFs
    """
    self._ensure_collected()
    return self._udtfs
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self) -> list[exp.Select | exp.SetOperation]:
    """
    List of subqueries in this scope.

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    Collection runs first if it has not run yet.

    Returns:
        list[exp.Select | exp.SetOperation]: subqueries
    """
    self._ensure_collected()
    return self._subqueries
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.SetOperation]: subqueries
@property
def stars(self) -> list[exp.Column | exp.Dot]:
    """
    List of star expressions (columns or dots) in this scope.

    Collection runs first if it has not run yet.
    """
    self._ensure_collected()
    return self._stars
List of star expressions (columns or dots) in this scope.
@property
def column_index(self) -> set[int]:
    """
    Set of column object IDs that belong to this scope's expression.

    Collection runs first if it has not run yet.
    """
    self._ensure_collected()
    return self._column_index
Set of column object IDs that belong to this scope's expression.
@property
def columns(self) -> list[exp.Column]:
    """
    List of columns in this scope.

    The list is computed on first access and memoized in `_columns`.

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is None:
        self._ensure_collected()
        columns = self._raw_columns

        # Columns that child scopes report as external; derived tables only
        # contribute when they are flagged as possibly correlated.
        external_columns = [
            column
            for scope in itertools.chain(
                self.subquery_scopes,
                self.udtf_scopes,
                (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
            )
            for column in scope.external_columns
        ]

        expr = self.expression
        named_selects = set(expr.named_selects) if isinstance(expr, exp.Query) else set()

        self._columns = []
        for column in columns + external_columns:
            ancestor = column.find_ancestor(
                exp.Select,
                exp.Qualify,
                exp.Order,
                exp.Having,
                exp.Hint,
                exp.Table,
                exp.Star,
                exp.Distinct,
            )
            # A column is dropped only when it is unqualified and the ancestor found
            # above is one of the cases not accepted below (e.g. an ORDER BY / DISTINCT
            # directly under a Select whose name matches a named select).
            if (
                not ancestor
                or column.text("table")
                or isinstance(ancestor, exp.Select)
                or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                or (
                    isinstance(ancestor, (exp.Order, exp.Distinct))
                    and (
                        isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
                        or not isinstance(ancestor.parent, exp.Select)
                        or column.name not in named_selects
                    )
                )
                or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
            ):
                self._columns.append(column)

    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self) -> dict[str, tuple[exp.Selectable, exp.Table | Scope]]:
    """
    Sources that this scope actually selects from, keyed by name.

    Every source is selectable in principle, but it only becomes a selected
    source once it is referenced in this scope (FROM / JOIN). Each value is
    the referencing node paired with the source it resolves to. The mapping
    is built once and memoized.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

    Raises:
        OptimizeError: if the same name is selected more than once.
    """
    cached = self._selected_sources
    if cached is not None:
        return cached

    selected: dict[str, tuple[exp.Selectable, exp.Table | Scope]] = {}
    for name, node in self.references:
        # The RHS table of SEMI/ANTI joins shouldn't be collected as a
        # selected source
        if name in self._semi_anti_join_tables:
            continue

        if name in selected:
            raise OptimizeError(f"Alias already used: {name}")

        # References that do not resolve to a known source are left out.
        if name in self.sources:
            selected[name] = (node, self.sources[name])

    self._selected_sources = selected
    return selected
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> list[tuple[str, exp.Selectable]]:
    """
    (name, node) pairs for every table, derived table and UDTF in this scope.

    Tables come first, under their alias or name; derived tables and UDTFs
    follow, named via `_get_source_alias`. The list is built once and memoized.
    """
    if self._references is not None:
        return self._references

    refs: list[tuple[str, exp.Selectable]] = []
    # Publish the list before filling it, as the cache slot is what gets returned.
    self._references = refs

    refs.extend((table.alias_or_name, table) for table in self.tables)

    for _expr in itertools.chain(self.derived_tables, self.udtfs):
        # TODO (mypyc): rebind to exp.Expr to avoid DerivedTable trait vtable dispatch
        expression: exp.Expr = _expr
        alias = _get_source_alias(expression)
        # Keep the wrapper node when it carries pivots; otherwise unnest it.
        if expression.args.get("pivots"):
            target = expression
        else:
            target = expression.unnest()
        refs.append((alias, target.assert_is(exp.Selectable)))

    return refs
@property
def external_columns(self) -> list[exp.Column]:
    """
    Columns that appear to reference sources in outer scopes.

    For a set operation this is the concatenation of the left and right
    operand scopes' external columns. Otherwise it is every column whose
    table qualifier is neither a name referenced in this scope nor a
    SEMI/ANTI join table. The list is built once and memoized.

    Returns:
        list[exp.Column]: Column instances that don't reference sources in the current scope.
    """
    if self._external_columns is not None:
        return self._external_columns

    found: list[exp.Column]
    if isinstance(self.expression, exp.SetOperation):
        left, right = self.union_scopes
        found = left.external_columns + right.external_columns
    else:
        local_source_names = {name for name, _ in self.references}
        found = []
        for column in self.columns:
            qualifier = column.text("table")
            if qualifier in local_source_names:
                continue
            if qualifier in self.semi_or_anti_join_tables:
                continue
            found.append(column)

    self._external_columns = found
    return found
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def local_columns(self) -> list[exp.Column]:
    """
    Columns of this scope that are not external.

    The list is built once and memoized.

    Returns:
        list[exp.Column]: Column instances that reference sources in the current scope.
    """
    if self._local_columns is not None:
        return self._local_columns

    # Set membership keeps the filter below linear in the number of columns.
    outer_refs = set(self.external_columns)
    kept: list[exp.Column] = []
    for column in self.columns:
        if column not in outer_refs:
            kept.append(column)

    self._local_columns = kept
    return kept
Columns in this scope that are not external.
Returns:
list[exp.Column]: Column instances that reference sources in the current scope.
@property
def unqualified_columns(self) -> list[exp.Column]:
    """
    Unqualified columns in the current scope.

    A column counts as unqualified when its "table" text is empty. The list
    is rebuilt from `columns` on every access.

    Returns:
        list[exp.Column]: Unqualified columns
    """
    return [c for c in self.columns if not c.text("table")]
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self) -> list[exp.JoinHint]:
    """
    Hints that exist in the scope that reference tables

    Collection runs first if it has not run yet.

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    self._ensure_collected()
    return self._join_hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name: str) -> list[exp.Column]:
    """
    Get all columns in the current scope for a particular source.

    Matching is an exact comparison against each column's "table" text.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances that reference `source_name`
    """
    return [column for column in self.columns if column.text("table") == source_name]
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self) -> bool:
    """Determine if this scope is a subquery, i.e. its scope type is `ScopeType.SUBQUERY`."""
    return self.scope_type == ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self) -> bool:
    """Determine if this scope is a derived table, i.e. its scope type is `ScopeType.DERIVED_TABLE`."""
    return self.scope_type == ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self) -> bool:
    """Determine if this scope is a union, i.e. its scope type is `ScopeType.UNION`."""
    return self.scope_type == ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self) -> bool:
    """Determine if this scope is a common table expression, i.e. its scope type is `ScopeType.CTE`."""
    return self.scope_type == ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self) -> bool:
    """Determine if this is the root scope, i.e. its scope type is `ScopeType.ROOT`."""
    return self.scope_type == ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self) -> bool:
    """Determine if this scope is a UDTF (User Defined Table Function), i.e. its scope type is `ScopeType.UDTF`."""
    return self.scope_type == ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name: str | None, new_name: str) -> None:
    """
    Rename a source in this scope.

    Nothing happens when `old_name` is not a known source. After a rename
    the cached state is cleared, as `add_source` and `remove_source` do,
    so values memoized under the old name are not served again.

    Args:
        old_name: Current name of the source; `None` is treated as "".
        new_name: Name to register the source under.
    """
    old_name = old_name or ""
    if old_name not in self.sources:
        return

    self.sources[new_name] = self.sources.pop(old_name)
    # Memoized lookups such as the selected sources are keyed by source name.
    self.clear_cache()
Rename a source in this scope
def add_source(self, name: str, source: exp.Table | Scope) -> None:
    """
    Add a source to this scope.

    An existing source with the same name is overwritten. Cached state is
    cleared afterwards.
    """
    self.sources[name] = source
    self.clear_cache()
Add a source to this scope
def remove_source(self, name: str) -> None:
    """
    Remove a source from this scope.

    An unknown name is ignored. Cached state is cleared either way.
    """
    self.sources.pop(name, None)
    self.clear_cache()
Remove a source from this scope
def traverse(self) -> Iterator[Scope]:
    """
    Walk the scope tree rooted at this scope.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    pending: list[Scope] = [self]
    visited: list[Scope] = []

    while pending:
        current = pending.pop()
        visited.append(current)
        # Children are pushed in a fixed order: CTEs, set-operation operands,
        # table scopes, then subqueries.
        pending.extend(current.cte_scopes)
        pending.extend(current.union_scopes)
        pending.extend(current.table_scopes)
        pending.extend(current.subquery_scopes)

    # `visited` holds every scope before its descendants; draining it from
    # the end yields children before their parents.
    while visited:
        yield visited.pop()
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self) -> dict[int, int]:
    """
    Count how often each source in this scope tree is referenced.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts: dict[int, int] = defaultdict(int)

    for scope in self.traverse():
        for selected in scope.selected_sources.values():
            # Each entry is a (node, source) pair; only the source is counted.
            counts[id(selected[1])] += 1

        # semi/anti join sources are not actually selected but we still need to
        # increment their ref count to avoid them being optimized away
        for name in scope._semi_anti_join_tables:
            if name not in scope.sources:
                continue
            counts[id(scope.sources[name])] += 1

    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expr) -> list[Scope]:
    """
    Build every scope of an expression and return them as a list.

    A "Scope" represents the current context of a Select statement. Optimizer
    rules often need more information than the expression tree alone offers,
    for instance the source names that are visible inside a subquery.

    The scopes are returned as a list rather than a generator, because a
    partially consumed generator could expose scopes with incomplete properties.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expr to traverse

    Returns:
        A list of the created scope instances. The list is empty when
        `expression` is not an instance of one of the `TRAVERSABLES` types.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expr to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expr) -> Scope | None:
    """
    Build the scope tree of an expression and return its root.

    Args:
        expression: Expr to build the scope tree for.

    Returns:
        The root scope, i.e. the last scope produced by `traverse_scope`,
        or None when no scope was created.
    """
    scopes = traverse_scope(expression)

    # The root scope is the final element of the list.
    return seq_get(scopes, -1)
Build a scope tree.
Arguments:
- expression: Expr to build the scope tree for.
Returns:
The root scope
def walk_in_scope(
    expression: exp.Expr,
    prune: t.Callable[[exp.Expr], bool] | None = None,
) -> Iterator[exp.Expr]:
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes. A node that starts a child scope is still yielded
    itself, but the walk does not descend into it.

    This does a custom DFS traversal rather than using expression.walk() because the
    nested generators aren't optimized in mypyc.

    Args:
        expression: the root expression to start the walk from.
        prune: callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expr: each node in scope
    """
    pending: list[exp.Expr] = [expression]

    while pending:
        current = pending.pop()

        yield current

        # The root expression never counts as the start of a child scope.
        starts_child_scope = current is not expression and (
            isinstance(current, exp.CTE)
            or (isinstance(current.parent, (exp.From, exp.Join)) and _is_derived_table(current))
            or (isinstance(current.parent, exp.UDTF) and isinstance(current, exp.Query))
            or isinstance(current, exp.UNWRAPPED_QUERIES)
        )

        if starts_child_scope:
            if isinstance(current, (exp.Subquery, exp.UDTF)):
                # These args are still walked even though the node itself is not
                # descended into (presumably they belong to the enclosing scope).
                # NOTE(review): `prune` is not forwarded to this nested walk --
                # confirm that this is intentional.
                for key in ("joins", "laterals", "pivots"):
                    for arg in current.args.get(key) or []:
                        yield from walk_in_scope(arg)
            continue

        if prune and prune(current):
            continue

        # Push children in reverse so that they are popped in their original order.
        for value in reversed(current.args.values()):
            if isinstance(value, exp.Expr):
                pending.append(value)
            elif isinstance(value, list):
                for item in reversed(value):
                    if isinstance(item, exp.Expr):
                        pending.append(item)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes. This does a custom DFS traversal rather than using expression.walk() because the nested generators aren't optimized in mypyc.
Arguments:
- expression: the root expression to start the walk from.
- prune: callable that returns True if the generator should stop traversing this branch of the tree.
Yields:
exp.Expr: each node in scope
def find_all_in_scope(
    expression: exp.Expr,
    *expression_types: Type[E],
) -> Iterator[E]:
    """
    Returns a generator object which walks the nodes of this scope and yields only
    the ones that are instances of at least one of the given expression types.

    This does NOT traverse into subscopes.

    Args:
        expression: the expression to search.
        expression_types: the expression type(s) to match.

    Yields:
        The matching nodes.
    """
    for candidate in walk_in_scope(expression):
        if not isinstance(candidate, expression_types):
            continue
        yield candidate
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression: the expression to search.
- expression_types: the expression type(s) to match.
Yields:
The matching nodes.
def find_in_scope(
    expression: exp.Expr,
    *expression_types: Type[E],
) -> E | None:
    """
    Returns the first node of this scope that is an instance of at least one of
    the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression: the expression to search.
        expression_types: the expression type(s) to match.

    Returns:
        The node which matches the criteria or None if no node matching
        the criteria was found.
    """
    # Stop at the first match; the rest of the walk is never consumed.
    for match in find_all_in_scope(expression, *expression_types):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression: the expression to search.
- expression_types: the expression type(s) to match.
Returns:
The node which matches the criteria or None if no node matching the criteria was found.