sqlglot.optimizer.scope
1from __future__ import annotations 2 3import itertools 4import logging 5import typing as t 6from collections import defaultdict 7from enum import Enum, auto 8 9from sqlglot import exp 10from sqlglot.errors import OptimizeError 11from sqlglot.helper import ensure_collection, find_new_name, seq_get 12 13logger = logging.getLogger("sqlglot") 14 15TRAVERSABLES = (exp.Query, exp.DDL, exp.DML) 16 17 18class ScopeType(Enum): 19 ROOT = auto() 20 SUBQUERY = auto() 21 DERIVED_TABLE = auto() 22 CTE = auto() 23 UNION = auto() 24 UDTF = auto() 25 26 27class Scope: 28 """ 29 Selection scope. 30 31 Attributes: 32 expression (exp.Select|exp.SetOperation): Root expression of this scope 33 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 34 a Table expression or another Scope instance. For example: 35 SELECT * FROM x {"x": Table(this="x")} 36 SELECT * FROM x AS y {"y": Table(this="x")} 37 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 38 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 39 For example: 40 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 41 The LATERAL VIEW EXPLODE gets x as a source. 42 cte_sources (dict[str, Scope]): Sources from CTES 43 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 44 defines a column list for the alias of this scope, this is that list of columns. 45 For example: 46 SELECT * FROM (SELECT ...) 
AS y(col1, col2) 47 The inner query would have `["col1", "col2"]` for its `outer_columns` 48 parent (Scope): Parent scope 49 scope_type (ScopeType): Type of this scope, relative to it's parent 50 subquery_scopes (list[Scope]): List of all child scopes for subqueries 51 cte_scopes (list[Scope]): List of all child scopes for CTEs 52 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 53 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 54 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 55 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 56 a list of the left and right child scopes. 57 """ 58 59 def __init__( 60 self, 61 expression, 62 sources=None, 63 outer_columns=None, 64 parent=None, 65 scope_type=ScopeType.ROOT, 66 lateral_sources=None, 67 cte_sources=None, 68 can_be_correlated=None, 69 ): 70 self.expression = expression 71 self.sources = sources or {} 72 self.lateral_sources = lateral_sources or {} 73 self.cte_sources = cte_sources or {} 74 self.sources.update(self.lateral_sources) 75 self.sources.update(self.cte_sources) 76 self.outer_columns = outer_columns or [] 77 self.parent = parent 78 self.scope_type = scope_type 79 self.subquery_scopes = [] 80 self.derived_table_scopes = [] 81 self.table_scopes = [] 82 self.cte_scopes = [] 83 self.union_scopes = [] 84 self.udtf_scopes = [] 85 self.can_be_correlated = can_be_correlated 86 self.clear_cache() 87 88 def clear_cache(self): 89 self._collected = False 90 self._raw_columns = None 91 self._stars = None 92 self._derived_tables = None 93 self._udtfs = None 94 self._tables = None 95 self._ctes = None 96 self._subqueries = None 97 self._selected_sources = None 98 self._columns = None 99 self._external_columns = None 100 self._join_hints = None 101 self._pivots = None 102 self._references = None 103 self._semi_anti_join_tables = None 104 105 def branch( 106 self, 
expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 107 ): 108 """Branch from the current scope to a new, inner scope""" 109 return Scope( 110 expression=expression.unnest(), 111 sources=sources.copy() if sources else None, 112 parent=self, 113 scope_type=scope_type, 114 cte_sources={**self.cte_sources, **(cte_sources or {})}, 115 lateral_sources=lateral_sources.copy() if lateral_sources else None, 116 can_be_correlated=self.can_be_correlated 117 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 118 **kwargs, 119 ) 120 121 def _collect(self): 122 self._tables = [] 123 self._ctes = [] 124 self._subqueries = [] 125 self._derived_tables = [] 126 self._udtfs = [] 127 self._raw_columns = [] 128 self._stars = [] 129 self._join_hints = [] 130 self._semi_anti_join_tables = set() 131 132 for node in self.walk(bfs=False): 133 if node is self.expression: 134 continue 135 136 if isinstance(node, exp.Dot) and node.is_star: 137 self._stars.append(node) 138 elif isinstance(node, exp.Column): 139 if isinstance(node.this, exp.Star): 140 self._stars.append(node) 141 else: 142 self._raw_columns.append(node) 143 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 144 parent = node.parent 145 if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join: 146 self._semi_anti_join_tables.add(node.alias_or_name) 147 148 self._tables.append(node) 149 elif isinstance(node, exp.JoinHint): 150 self._join_hints.append(node) 151 elif isinstance(node, exp.UDTF): 152 self._udtfs.append(node) 153 elif isinstance(node, exp.CTE): 154 self._ctes.append(node) 155 elif _is_derived_table(node) and _is_from_or_join(node): 156 self._derived_tables.append(node) 157 elif isinstance(node, exp.UNWRAPPED_QUERIES): 158 self._subqueries.append(node) 159 160 self._collected = True 161 162 def _ensure_collected(self): 163 if not self._collected: 164 self._collect() 165 166 def walk(self, bfs=True, prune=None): 167 return 
walk_in_scope(self.expression, bfs=bfs, prune=None) 168 169 def find(self, *expression_types, bfs=True): 170 return find_in_scope(self.expression, expression_types, bfs=bfs) 171 172 def find_all(self, *expression_types, bfs=True): 173 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 174 175 def replace(self, old, new): 176 """ 177 Replace `old` with `new`. 178 179 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 180 181 Args: 182 old (exp.Expression): old node 183 new (exp.Expression): new node 184 """ 185 old.replace(new) 186 self.clear_cache() 187 188 @property 189 def tables(self): 190 """ 191 List of tables in this scope. 192 193 Returns: 194 list[exp.Table]: tables 195 """ 196 self._ensure_collected() 197 return self._tables 198 199 @property 200 def ctes(self): 201 """ 202 List of CTEs in this scope. 203 204 Returns: 205 list[exp.CTE]: ctes 206 """ 207 self._ensure_collected() 208 return self._ctes 209 210 @property 211 def derived_tables(self): 212 """ 213 List of derived tables in this scope. 214 215 For example: 216 SELECT * FROM (SELECT ...) <- that's a derived table 217 218 Returns: 219 list[exp.Subquery]: derived tables 220 """ 221 self._ensure_collected() 222 return self._derived_tables 223 224 @property 225 def udtfs(self): 226 """ 227 List of "User Defined Tabular Functions" in this scope. 228 229 Returns: 230 list[exp.UDTF]: UDTFs 231 """ 232 self._ensure_collected() 233 return self._udtfs 234 235 @property 236 def subqueries(self): 237 """ 238 List of subqueries in this scope. 239 240 For example: 241 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 242 243 Returns: 244 list[exp.Select | exp.SetOperation]: subqueries 245 """ 246 self._ensure_collected() 247 return self._subqueries 248 249 @property 250 def stars(self) -> t.List[exp.Column | exp.Dot]: 251 """ 252 List of star expressions (columns or dots) in this scope. 
253 """ 254 self._ensure_collected() 255 return self._stars 256 257 @property 258 def columns(self): 259 """ 260 List of columns in this scope. 261 262 Returns: 263 list[exp.Column]: Column instances in this scope, plus any 264 Columns that reference this scope from correlated subqueries. 265 """ 266 if self._columns is None: 267 self._ensure_collected() 268 columns = self._raw_columns 269 270 external_columns = [ 271 column 272 for scope in itertools.chain( 273 self.subquery_scopes, 274 self.udtf_scopes, 275 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 276 ) 277 for column in scope.external_columns 278 ] 279 280 named_selects = set(self.expression.named_selects) 281 282 self._columns = [] 283 for column in columns + external_columns: 284 ancestor = column.find_ancestor( 285 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star 286 ) 287 if ( 288 not ancestor 289 or column.table 290 or isinstance(ancestor, exp.Select) 291 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 292 or ( 293 isinstance(ancestor, exp.Order) 294 and ( 295 isinstance(ancestor.parent, exp.Window) 296 or column.name not in named_selects 297 ) 298 ) 299 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") 300 ): 301 self._columns.append(column) 302 303 return self._columns 304 305 @property 306 def selected_sources(self): 307 """ 308 Mapping of nodes and sources that are actually selected from in this scope. 309 310 That is, all tables in a schema are selectable at any point. But a 311 table only becomes a selected source if it's included in a FROM or JOIN clause. 
312 313 Returns: 314 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 315 """ 316 if self._selected_sources is None: 317 result = {} 318 319 for name, node in self.references: 320 if name in self._semi_anti_join_tables: 321 # The RHS table of SEMI/ANTI joins shouldn't be collected as a 322 # selected source 323 continue 324 325 if name in result: 326 raise OptimizeError(f"Alias already used: {name}") 327 if name in self.sources: 328 result[name] = (node, self.sources[name]) 329 330 self._selected_sources = result 331 return self._selected_sources 332 333 @property 334 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 335 if self._references is None: 336 self._references = [] 337 338 for table in self.tables: 339 self._references.append((table.alias_or_name, table)) 340 for expression in itertools.chain(self.derived_tables, self.udtfs): 341 self._references.append( 342 ( 343 expression.alias, 344 expression if expression.args.get("pivots") else expression.unnest(), 345 ) 346 ) 347 348 return self._references 349 350 @property 351 def external_columns(self): 352 """ 353 Columns that appear to reference sources in outer scopes. 354 355 Returns: 356 list[exp.Column]: Column instances that don't reference 357 sources in the current scope. 358 """ 359 if self._external_columns is None: 360 if isinstance(self.expression, exp.SetOperation): 361 left, right = self.union_scopes 362 self._external_columns = left.external_columns + right.external_columns 363 else: 364 self._external_columns = [ 365 c 366 for c in self.columns 367 if c.table not in self.selected_sources 368 and c.table not in self.semi_or_anti_join_tables 369 ] 370 371 return self._external_columns 372 373 @property 374 def unqualified_columns(self): 375 """ 376 Unqualified columns in the current scope. 
377 378 Returns: 379 list[exp.Column]: Unqualified columns 380 """ 381 return [c for c in self.columns if not c.table] 382 383 @property 384 def join_hints(self): 385 """ 386 Hints that exist in the scope that reference tables 387 388 Returns: 389 list[exp.JoinHint]: Join hints that are referenced within the scope 390 """ 391 if self._join_hints is None: 392 return [] 393 return self._join_hints 394 395 @property 396 def pivots(self): 397 if not self._pivots: 398 self._pivots = [ 399 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 400 ] 401 402 return self._pivots 403 404 @property 405 def semi_or_anti_join_tables(self): 406 return self._semi_anti_join_tables or set() 407 408 def source_columns(self, source_name): 409 """ 410 Get all columns in the current scope for a particular source. 411 412 Args: 413 source_name (str): Name of the source 414 Returns: 415 list[exp.Column]: Column instances that reference `source_name` 416 """ 417 return [column for column in self.columns if column.table == source_name] 418 419 @property 420 def is_subquery(self): 421 """Determine if this scope is a subquery""" 422 return self.scope_type == ScopeType.SUBQUERY 423 424 @property 425 def is_derived_table(self): 426 """Determine if this scope is a derived table""" 427 return self.scope_type == ScopeType.DERIVED_TABLE 428 429 @property 430 def is_union(self): 431 """Determine if this scope is a union""" 432 return self.scope_type == ScopeType.UNION 433 434 @property 435 def is_cte(self): 436 """Determine if this scope is a common table expression""" 437 return self.scope_type == ScopeType.CTE 438 439 @property 440 def is_root(self): 441 """Determine if this is the root scope""" 442 return self.scope_type == ScopeType.ROOT 443 444 @property 445 def is_udtf(self): 446 """Determine if this scope is a UDTF (User Defined Table Function)""" 447 return self.scope_type == ScopeType.UDTF 448 449 @property 450 def is_correlated_subquery(self): 451 """Determine if 
this scope is a correlated subquery""" 452 return bool(self.can_be_correlated and self.external_columns) 453 454 def rename_source(self, old_name, new_name): 455 """Rename a source in this scope""" 456 columns = self.sources.pop(old_name or "", []) 457 self.sources[new_name] = columns 458 459 def add_source(self, name, source): 460 """Add a source to this scope""" 461 self.sources[name] = source 462 self.clear_cache() 463 464 def remove_source(self, name): 465 """Remove a source from this scope""" 466 self.sources.pop(name, None) 467 self.clear_cache() 468 469 def __repr__(self): 470 return f"Scope<{self.expression.sql()}>" 471 472 def traverse(self): 473 """ 474 Traverse the scope tree from this node. 475 476 Yields: 477 Scope: scope instances in depth-first-search post-order 478 """ 479 stack = [self] 480 result = [] 481 while stack: 482 scope = stack.pop() 483 result.append(scope) 484 stack.extend( 485 itertools.chain( 486 scope.cte_scopes, 487 scope.union_scopes, 488 scope.table_scopes, 489 scope.subquery_scopes, 490 ) 491 ) 492 493 yield from reversed(result) 494 495 def ref_count(self): 496 """ 497 Count the number of times each scope in this tree is referenced. 498 499 Returns: 500 dict[int, int]: Mapping of Scope instance ID to reference count 501 """ 502 scope_ref_count = defaultdict(lambda: 0) 503 504 for scope in self.traverse(): 505 for _, source in scope.selected_sources.values(): 506 scope_ref_count[id(source)] += 1 507 508 return scope_ref_count 509 510 511def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 512 """ 513 Traverse an expression by its "scopes". 514 515 "Scope" represents the current context of a Select statement. 516 517 This is helpful for optimizing queries, where we need more information than 518 the expression tree itself. For example, we might care about the source 519 names within a subquery. Returns a list because a generator could result in 520 incomplete properties which is confusing. 
521 522 Examples: 523 >>> import sqlglot 524 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 525 >>> scopes = traverse_scope(expression) 526 >>> scopes[0].expression.sql(), list(scopes[0].sources) 527 ('SELECT a FROM x', ['x']) 528 >>> scopes[1].expression.sql(), list(scopes[1].sources) 529 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 530 531 Args: 532 expression: Expression to traverse 533 534 Returns: 535 A list of the created scope instances 536 """ 537 if isinstance(expression, TRAVERSABLES): 538 return list(_traverse_scope(Scope(expression))) 539 return [] 540 541 542def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 543 """ 544 Build a scope tree. 545 546 Args: 547 expression: Expression to build the scope tree for. 548 549 Returns: 550 The root scope 551 """ 552 return seq_get(traverse_scope(expression), -1) 553 554 555def _traverse_scope(scope): 556 expression = scope.expression 557 558 if isinstance(expression, exp.Select): 559 yield from _traverse_select(scope) 560 elif isinstance(expression, exp.SetOperation): 561 yield from _traverse_ctes(scope) 562 yield from _traverse_union(scope) 563 return 564 elif isinstance(expression, exp.Subquery): 565 if scope.is_root: 566 yield from _traverse_select(scope) 567 else: 568 yield from _traverse_subqueries(scope) 569 elif isinstance(expression, exp.Table): 570 yield from _traverse_tables(scope) 571 elif isinstance(expression, exp.UDTF): 572 yield from _traverse_udtfs(scope) 573 elif isinstance(expression, exp.DDL): 574 if isinstance(expression.expression, exp.Query): 575 yield from _traverse_ctes(scope) 576 yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources)) 577 return 578 elif isinstance(expression, exp.DML): 579 yield from _traverse_ctes(scope) 580 for query in find_all_in_scope(expression, exp.Query): 581 # This check ensures we don't yield the CTE/nested queries twice 582 if not isinstance(query.parent, (exp.CTE, exp.Subquery)): 
583 yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources)) 584 return 585 else: 586 logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression)) 587 return 588 589 yield scope 590 591 592def _traverse_select(scope): 593 yield from _traverse_ctes(scope) 594 yield from _traverse_tables(scope) 595 yield from _traverse_subqueries(scope) 596 597 598def _traverse_union(scope): 599 prev_scope = None 600 union_scope_stack = [scope] 601 expression_stack = [scope.expression.right, scope.expression.left] 602 603 while expression_stack: 604 expression = expression_stack.pop() 605 union_scope = union_scope_stack[-1] 606 607 new_scope = union_scope.branch( 608 expression, 609 outer_columns=union_scope.outer_columns, 610 scope_type=ScopeType.UNION, 611 ) 612 613 if isinstance(expression, exp.SetOperation): 614 yield from _traverse_ctes(new_scope) 615 616 union_scope_stack.append(new_scope) 617 expression_stack.extend([expression.right, expression.left]) 618 continue 619 620 for scope in _traverse_scope(new_scope): 621 yield scope 622 623 if prev_scope: 624 union_scope_stack.pop() 625 union_scope.union_scopes = [prev_scope, scope] 626 prev_scope = union_scope 627 628 yield union_scope 629 else: 630 prev_scope = scope 631 632 633def _traverse_ctes(scope): 634 sources = {} 635 636 for cte in scope.ctes: 637 cte_name = cte.alias 638 639 # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. 640 # thus the recursive scope is the first section of the union. 
641 with_ = scope.expression.args.get("with") 642 if with_ and with_.recursive: 643 union = cte.this 644 645 if isinstance(union, exp.SetOperation): 646 sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE) 647 648 child_scope = None 649 650 for child_scope in _traverse_scope( 651 scope.branch( 652 cte.this, 653 cte_sources=sources, 654 outer_columns=cte.alias_column_names, 655 scope_type=ScopeType.CTE, 656 ) 657 ): 658 yield child_scope 659 660 # append the final child_scope yielded 661 if child_scope: 662 sources[cte_name] = child_scope 663 scope.cte_scopes.append(child_scope) 664 665 scope.sources.update(sources) 666 scope.cte_sources.update(sources) 667 668 669def _is_derived_table(expression: exp.Subquery) -> bool: 670 """ 671 We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", 672 as it doesn't introduce a new scope. If an alias is present, it shadows all names 673 under the Subquery, so that's one exception to this rule. 674 """ 675 return isinstance(expression, exp.Subquery) and bool( 676 expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES) 677 ) 678 679 680def _is_from_or_join(expression: exp.Expression) -> bool: 681 """ 682 Determine if `expression` is the FROM or JOIN clause of a SELECT statement. 
683 """ 684 parent = expression.parent 685 686 # Subqueries can be arbitrarily nested 687 while isinstance(parent, exp.Subquery): 688 parent = parent.parent 689 690 return isinstance(parent, (exp.From, exp.Join)) 691 692 693def _traverse_tables(scope): 694 sources = {} 695 696 # Traverse FROMs, JOINs, and LATERALs in the order they are defined 697 expressions = [] 698 from_ = scope.expression.args.get("from") 699 if from_: 700 expressions.append(from_.this) 701 702 for join in scope.expression.args.get("joins") or []: 703 expressions.append(join.this) 704 705 if isinstance(scope.expression, exp.Table): 706 expressions.append(scope.expression) 707 708 expressions.extend(scope.expression.args.get("laterals") or []) 709 710 for expression in expressions: 711 if isinstance(expression, exp.Final): 712 expression = expression.this 713 if isinstance(expression, exp.Table): 714 table_name = expression.name 715 source_name = expression.alias_or_name 716 717 if table_name in scope.sources and not expression.db: 718 # This is a reference to a parent source (e.g. a CTE), not an actual table, unless 719 # it is pivoted, because then we get back a new table and hence a new source. 
720 pivots = expression.args.get("pivots") 721 if pivots: 722 sources[pivots[0].alias] = expression 723 else: 724 sources[source_name] = scope.sources[table_name] 725 elif source_name in sources: 726 sources[find_new_name(sources, table_name)] = expression 727 else: 728 sources[source_name] = expression 729 730 # Make sure to not include the joins twice 731 if expression is not scope.expression: 732 expressions.extend(join.this for join in expression.args.get("joins") or []) 733 734 continue 735 736 if not isinstance(expression, exp.DerivedTable): 737 continue 738 739 if isinstance(expression, exp.UDTF): 740 lateral_sources = sources 741 scope_type = ScopeType.UDTF 742 scopes = scope.udtf_scopes 743 elif _is_derived_table(expression): 744 lateral_sources = None 745 scope_type = ScopeType.DERIVED_TABLE 746 scopes = scope.derived_table_scopes 747 expressions.extend(join.this for join in expression.args.get("joins") or []) 748 else: 749 # Makes sure we check for possible sources in nested table constructs 750 expressions.append(expression.this) 751 expressions.extend(join.this for join in expression.args.get("joins") or []) 752 continue 753 754 for child_scope in _traverse_scope( 755 scope.branch( 756 expression, 757 lateral_sources=lateral_sources, 758 outer_columns=expression.alias_column_names, 759 scope_type=scope_type, 760 ) 761 ): 762 yield child_scope 763 764 # Tables without aliases will be set as "" 765 # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. 766 # Until then, this means that only a single, unaliased derived table is allowed (rather, 767 # the latest one wins. 
768 sources[expression.alias] = child_scope 769 770 # append the final child_scope yielded 771 scopes.append(child_scope) 772 scope.table_scopes.append(child_scope) 773 774 scope.sources.update(sources) 775 776 777def _traverse_subqueries(scope): 778 for subquery in scope.subqueries: 779 top = None 780 for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): 781 yield child_scope 782 top = child_scope 783 scope.subquery_scopes.append(top) 784 785 786def _traverse_udtfs(scope): 787 if isinstance(scope.expression, exp.Unnest): 788 expressions = scope.expression.expressions 789 elif isinstance(scope.expression, exp.Lateral): 790 expressions = [scope.expression.this] 791 else: 792 expressions = [] 793 794 sources = {} 795 for expression in expressions: 796 if _is_derived_table(expression): 797 top = None 798 for child_scope in _traverse_scope( 799 scope.branch( 800 expression, 801 scope_type=ScopeType.SUBQUERY, 802 outer_columns=expression.alias_column_names, 803 ) 804 ): 805 yield child_scope 806 top = child_scope 807 sources[expression.alias] = child_scope 808 809 scope.subquery_scopes.append(top) 810 811 scope.sources.update(sources) 812 813 814def walk_in_scope(expression, bfs=True, prune=None): 815 """ 816 Returns a generator object which visits all nodes in the syntrax tree, stopping at 817 nodes that start child scopes. 818 819 Args: 820 expression (exp.Expression): 821 bfs (bool): if set to True the BFS traversal order will be applied, 822 otherwise the DFS traversal will be used instead. 823 prune ((node, parent, arg_key) -> bool): callable that returns True if 824 the generator should stop traversing this branch of the tree. 825 826 Yields: 827 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key 828 """ 829 # We'll use this variable to pass state into the dfs generator. 830 # Whenever we set it to True, we exclude a subtree from traversal. 
831 crossed_scope_boundary = False 832 833 for node in expression.walk( 834 bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n)) 835 ): 836 crossed_scope_boundary = False 837 838 yield node 839 840 if node is expression: 841 continue 842 if ( 843 isinstance(node, exp.CTE) 844 or ( 845 isinstance(node.parent, (exp.From, exp.Join, exp.Subquery)) 846 and (_is_derived_table(node) or isinstance(node, exp.UDTF)) 847 ) 848 or isinstance(node, exp.UNWRAPPED_QUERIES) 849 ): 850 crossed_scope_boundary = True 851 852 if isinstance(node, (exp.Subquery, exp.UDTF)): 853 # The following args are not actually in the inner scope, so we should visit them 854 for key in ("joins", "laterals", "pivots"): 855 for arg in node.args.get(key) or []: 856 yield from walk_in_scope(arg, bfs=bfs) 857 858 859def find_all_in_scope(expression, expression_types, bfs=True): 860 """ 861 Returns a generator object which visits all nodes in this scope and only yields those that 862 match at least one of the specified expression types. 863 864 This does NOT traverse into subscopes. 865 866 Args: 867 expression (exp.Expression): 868 expression_types (tuple[type]|type): the expression type(s) to match. 869 bfs (bool): True to use breadth-first search, False to use depth-first. 870 871 Yields: 872 exp.Expression: nodes 873 """ 874 for expression in walk_in_scope(expression, bfs=bfs): 875 if isinstance(expression, tuple(ensure_collection(expression_types))): 876 yield expression 877 878 879def find_in_scope(expression, expression_types, bfs=True): 880 """ 881 Returns the first node in this scope which matches at least one of the specified types. 882 883 This does NOT traverse into subscopes. 884 885 Args: 886 expression (exp.Expression): 887 expression_types (tuple[type]|type): the expression type(s) to match. 888 bfs (bool): True to use breadth-first search, False to use depth-first. 
889 890 Returns: 891 exp.Expression: the node which matches the criteria or None if no node matching 892 the criteria was found. 893 """ 894 return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    # The explicit values are the ones `auto()` assigns (1, 2, 3, ...).
    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6
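An enumeration of the kinds of scope, relative to the parent scope: ROOT, SUBQUERY, DERIVED_TABLE, CTE, UNION and UDTF.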
28class Scope: 29 """ 30 Selection scope. 31 32 Attributes: 33 expression (exp.Select|exp.SetOperation): Root expression of this scope 34 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 35 a Table expression or another Scope instance. For example: 36 SELECT * FROM x {"x": Table(this="x")} 37 SELECT * FROM x AS y {"y": Table(this="x")} 38 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 39 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 40 For example: 41 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 42 The LATERAL VIEW EXPLODE gets x as a source. 43 cte_sources (dict[str, Scope]): Sources from CTES 44 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 45 defines a column list for the alias of this scope, this is that list of columns. 46 For example: 47 SELECT * FROM (SELECT ...) AS y(col1, col2) 48 The inner query would have `["col1", "col2"]` for its `outer_columns` 49 parent (Scope): Parent scope 50 scope_type (ScopeType): Type of this scope, relative to it's parent 51 subquery_scopes (list[Scope]): List of all child scopes for subqueries 52 cte_scopes (list[Scope]): List of all child scopes for CTEs 53 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 54 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 55 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 56 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 57 a list of the left and right child scopes. 
58 """ 59 60 def __init__( 61 self, 62 expression, 63 sources=None, 64 outer_columns=None, 65 parent=None, 66 scope_type=ScopeType.ROOT, 67 lateral_sources=None, 68 cte_sources=None, 69 can_be_correlated=None, 70 ): 71 self.expression = expression 72 self.sources = sources or {} 73 self.lateral_sources = lateral_sources or {} 74 self.cte_sources = cte_sources or {} 75 self.sources.update(self.lateral_sources) 76 self.sources.update(self.cte_sources) 77 self.outer_columns = outer_columns or [] 78 self.parent = parent 79 self.scope_type = scope_type 80 self.subquery_scopes = [] 81 self.derived_table_scopes = [] 82 self.table_scopes = [] 83 self.cte_scopes = [] 84 self.union_scopes = [] 85 self.udtf_scopes = [] 86 self.can_be_correlated = can_be_correlated 87 self.clear_cache() 88 89 def clear_cache(self): 90 self._collected = False 91 self._raw_columns = None 92 self._stars = None 93 self._derived_tables = None 94 self._udtfs = None 95 self._tables = None 96 self._ctes = None 97 self._subqueries = None 98 self._selected_sources = None 99 self._columns = None 100 self._external_columns = None 101 self._join_hints = None 102 self._pivots = None 103 self._references = None 104 self._semi_anti_join_tables = None 105 106 def branch( 107 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 108 ): 109 """Branch from the current scope to a new, inner scope""" 110 return Scope( 111 expression=expression.unnest(), 112 sources=sources.copy() if sources else None, 113 parent=self, 114 scope_type=scope_type, 115 cte_sources={**self.cte_sources, **(cte_sources or {})}, 116 lateral_sources=lateral_sources.copy() if lateral_sources else None, 117 can_be_correlated=self.can_be_correlated 118 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 119 **kwargs, 120 ) 121 122 def _collect(self): 123 self._tables = [] 124 self._ctes = [] 125 self._subqueries = [] 126 self._derived_tables = [] 127 self._udtfs = [] 128 self._raw_columns = [] 129 
self._stars = [] 130 self._join_hints = [] 131 self._semi_anti_join_tables = set() 132 133 for node in self.walk(bfs=False): 134 if node is self.expression: 135 continue 136 137 if isinstance(node, exp.Dot) and node.is_star: 138 self._stars.append(node) 139 elif isinstance(node, exp.Column): 140 if isinstance(node.this, exp.Star): 141 self._stars.append(node) 142 else: 143 self._raw_columns.append(node) 144 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 145 parent = node.parent 146 if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join: 147 self._semi_anti_join_tables.add(node.alias_or_name) 148 149 self._tables.append(node) 150 elif isinstance(node, exp.JoinHint): 151 self._join_hints.append(node) 152 elif isinstance(node, exp.UDTF): 153 self._udtfs.append(node) 154 elif isinstance(node, exp.CTE): 155 self._ctes.append(node) 156 elif _is_derived_table(node) and _is_from_or_join(node): 157 self._derived_tables.append(node) 158 elif isinstance(node, exp.UNWRAPPED_QUERIES): 159 self._subqueries.append(node) 160 161 self._collected = True 162 163 def _ensure_collected(self): 164 if not self._collected: 165 self._collect() 166 167 def walk(self, bfs=True, prune=None): 168 return walk_in_scope(self.expression, bfs=bfs, prune=None) 169 170 def find(self, *expression_types, bfs=True): 171 return find_in_scope(self.expression, expression_types, bfs=bfs) 172 173 def find_all(self, *expression_types, bfs=True): 174 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 175 176 def replace(self, old, new): 177 """ 178 Replace `old` with `new`. 179 180 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 181 182 Args: 183 old (exp.Expression): old node 184 new (exp.Expression): new node 185 """ 186 old.replace(new) 187 self.clear_cache() 188 189 @property 190 def tables(self): 191 """ 192 List of tables in this scope. 
193 194 Returns: 195 list[exp.Table]: tables 196 """ 197 self._ensure_collected() 198 return self._tables 199 200 @property 201 def ctes(self): 202 """ 203 List of CTEs in this scope. 204 205 Returns: 206 list[exp.CTE]: ctes 207 """ 208 self._ensure_collected() 209 return self._ctes 210 211 @property 212 def derived_tables(self): 213 """ 214 List of derived tables in this scope. 215 216 For example: 217 SELECT * FROM (SELECT ...) <- that's a derived table 218 219 Returns: 220 list[exp.Subquery]: derived tables 221 """ 222 self._ensure_collected() 223 return self._derived_tables 224 225 @property 226 def udtfs(self): 227 """ 228 List of "User Defined Tabular Functions" in this scope. 229 230 Returns: 231 list[exp.UDTF]: UDTFs 232 """ 233 self._ensure_collected() 234 return self._udtfs 235 236 @property 237 def subqueries(self): 238 """ 239 List of subqueries in this scope. 240 241 For example: 242 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 243 244 Returns: 245 list[exp.Select | exp.SetOperation]: subqueries 246 """ 247 self._ensure_collected() 248 return self._subqueries 249 250 @property 251 def stars(self) -> t.List[exp.Column | exp.Dot]: 252 """ 253 List of star expressions (columns or dots) in this scope. 254 """ 255 self._ensure_collected() 256 return self._stars 257 258 @property 259 def columns(self): 260 """ 261 List of columns in this scope. 262 263 Returns: 264 list[exp.Column]: Column instances in this scope, plus any 265 Columns that reference this scope from correlated subqueries. 
266 """ 267 if self._columns is None: 268 self._ensure_collected() 269 columns = self._raw_columns 270 271 external_columns = [ 272 column 273 for scope in itertools.chain( 274 self.subquery_scopes, 275 self.udtf_scopes, 276 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 277 ) 278 for column in scope.external_columns 279 ] 280 281 named_selects = set(self.expression.named_selects) 282 283 self._columns = [] 284 for column in columns + external_columns: 285 ancestor = column.find_ancestor( 286 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star 287 ) 288 if ( 289 not ancestor 290 or column.table 291 or isinstance(ancestor, exp.Select) 292 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 293 or ( 294 isinstance(ancestor, exp.Order) 295 and ( 296 isinstance(ancestor.parent, exp.Window) 297 or column.name not in named_selects 298 ) 299 ) 300 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") 301 ): 302 self._columns.append(column) 303 304 return self._columns 305 306 @property 307 def selected_sources(self): 308 """ 309 Mapping of nodes and sources that are actually selected from in this scope. 310 311 That is, all tables in a schema are selectable at any point. But a 312 table only becomes a selected source if it's included in a FROM or JOIN clause. 
313 314 Returns: 315 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 316 """ 317 if self._selected_sources is None: 318 result = {} 319 320 for name, node in self.references: 321 if name in self._semi_anti_join_tables: 322 # The RHS table of SEMI/ANTI joins shouldn't be collected as a 323 # selected source 324 continue 325 326 if name in result: 327 raise OptimizeError(f"Alias already used: {name}") 328 if name in self.sources: 329 result[name] = (node, self.sources[name]) 330 331 self._selected_sources = result 332 return self._selected_sources 333 334 @property 335 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 336 if self._references is None: 337 self._references = [] 338 339 for table in self.tables: 340 self._references.append((table.alias_or_name, table)) 341 for expression in itertools.chain(self.derived_tables, self.udtfs): 342 self._references.append( 343 ( 344 expression.alias, 345 expression if expression.args.get("pivots") else expression.unnest(), 346 ) 347 ) 348 349 return self._references 350 351 @property 352 def external_columns(self): 353 """ 354 Columns that appear to reference sources in outer scopes. 355 356 Returns: 357 list[exp.Column]: Column instances that don't reference 358 sources in the current scope. 359 """ 360 if self._external_columns is None: 361 if isinstance(self.expression, exp.SetOperation): 362 left, right = self.union_scopes 363 self._external_columns = left.external_columns + right.external_columns 364 else: 365 self._external_columns = [ 366 c 367 for c in self.columns 368 if c.table not in self.selected_sources 369 and c.table not in self.semi_or_anti_join_tables 370 ] 371 372 return self._external_columns 373 374 @property 375 def unqualified_columns(self): 376 """ 377 Unqualified columns in the current scope. 
378 379 Returns: 380 list[exp.Column]: Unqualified columns 381 """ 382 return [c for c in self.columns if not c.table] 383 384 @property 385 def join_hints(self): 386 """ 387 Hints that exist in the scope that reference tables 388 389 Returns: 390 list[exp.JoinHint]: Join hints that are referenced within the scope 391 """ 392 if self._join_hints is None: 393 return [] 394 return self._join_hints 395 396 @property 397 def pivots(self): 398 if not self._pivots: 399 self._pivots = [ 400 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 401 ] 402 403 return self._pivots 404 405 @property 406 def semi_or_anti_join_tables(self): 407 return self._semi_anti_join_tables or set() 408 409 def source_columns(self, source_name): 410 """ 411 Get all columns in the current scope for a particular source. 412 413 Args: 414 source_name (str): Name of the source 415 Returns: 416 list[exp.Column]: Column instances that reference `source_name` 417 """ 418 return [column for column in self.columns if column.table == source_name] 419 420 @property 421 def is_subquery(self): 422 """Determine if this scope is a subquery""" 423 return self.scope_type == ScopeType.SUBQUERY 424 425 @property 426 def is_derived_table(self): 427 """Determine if this scope is a derived table""" 428 return self.scope_type == ScopeType.DERIVED_TABLE 429 430 @property 431 def is_union(self): 432 """Determine if this scope is a union""" 433 return self.scope_type == ScopeType.UNION 434 435 @property 436 def is_cte(self): 437 """Determine if this scope is a common table expression""" 438 return self.scope_type == ScopeType.CTE 439 440 @property 441 def is_root(self): 442 """Determine if this is the root scope""" 443 return self.scope_type == ScopeType.ROOT 444 445 @property 446 def is_udtf(self): 447 """Determine if this scope is a UDTF (User Defined Table Function)""" 448 return self.scope_type == ScopeType.UDTF 449 450 @property 451 def is_correlated_subquery(self): 452 """Determine if 
this scope is a correlated subquery""" 453 return bool(self.can_be_correlated and self.external_columns) 454 455 def rename_source(self, old_name, new_name): 456 """Rename a source in this scope""" 457 columns = self.sources.pop(old_name or "", []) 458 self.sources[new_name] = columns 459 460 def add_source(self, name, source): 461 """Add a source to this scope""" 462 self.sources[name] = source 463 self.clear_cache() 464 465 def remove_source(self, name): 466 """Remove a source from this scope""" 467 self.sources.pop(name, None) 468 self.clear_cache() 469 470 def __repr__(self): 471 return f"Scope<{self.expression.sql()}>" 472 473 def traverse(self): 474 """ 475 Traverse the scope tree from this node. 476 477 Yields: 478 Scope: scope instances in depth-first-search post-order 479 """ 480 stack = [self] 481 result = [] 482 while stack: 483 scope = stack.pop() 484 result.append(scope) 485 stack.extend( 486 itertools.chain( 487 scope.cte_scopes, 488 scope.union_scopes, 489 scope.table_scopes, 490 scope.subquery_scopes, 491 ) 492 ) 493 494 yield from reversed(result) 495 496 def ref_count(self): 497 """ 498 Count the number of times each scope in this tree is referenced. 499 500 Returns: 501 dict[int, int]: Mapping of Scope instance ID to reference count 502 """ 503 scope_ref_count = defaultdict(lambda: 0) 504 505 for scope in self.traverse(): 506 for _, source in scope.selected_sources.values(): 507 scope_ref_count[id(source)] += 1 508 509 return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.SetOperation): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have
["col1", "col2"]
for its outer_columns
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_columns=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
    can_be_correlated=None,
):
    """Initialise the scope's expression, sources, child-scope lists and caches."""
    # Identity and position of this scope in the tree.
    self.expression = expression
    self.parent = parent
    self.scope_type = scope_type
    self.can_be_correlated = can_be_correlated
    self.outer_columns = outer_columns or []

    # Lateral and CTE sources are kept separately and also merged into `sources`.
    self.lateral_sources = lateral_sources or {}
    self.cte_sources = cte_sources or {}
    self.sources = sources or {}
    for extra_sources in (self.lateral_sources, self.cte_sources):
        self.sources.update(extra_sources)

    # Child scopes start out empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    self.clear_cache()
def clear_cache(self):
    """Mark the scope as not collected and drop every cached attribute."""
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_stars",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
        "_semi_anti_join_tables",
    ):
        setattr(self, cached_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Create a child scope of this scope.

    The child is built from the unnested `expression`, gets copies of the given
    `sources` and `lateral_sources`, and sees this scope's CTE sources plus
    `cte_sources`. Extra keyword arguments are passed on to `Scope`.
    """
    inner_expression = expression.unnest()
    inner_sources = sources.copy() if sources else None
    inner_cte_sources = {**self.cte_sources, **(cte_sources or {})}
    inner_lateral_sources = lateral_sources.copy() if lateral_sources else None
    correlatable = self.can_be_correlated or scope_type in (
        ScopeType.SUBQUERY,
        ScopeType.UDTF,
    )

    return Scope(
        expression=inner_expression,
        sources=inner_sources,
        parent=self,
        scope_type=scope_type,
        cte_sources=inner_cte_sources,
        lateral_sources=inner_lateral_sources,
        can_be_correlated=correlatable,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` and invalidate this scope's caches.

    Use this instead of `exp.Expression.replace` so the `Scope` stays up-to-date.

    Args:
        old (exp.Expression): node to be replaced
        new (exp.Expression): node that takes its place
    """
    old.replace(new)
    # Anything collected earlier may still point at `old`.
    self.clear_cache()
Replace old
with new
.
This can be used instead of exp.Expression.replace
to ensure the Scope
is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Tables found in this scope, collected on first access.

    Returns:
        list[exp.Table]: the table nodes
    """
    self._ensure_collected()
    return self._tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    CTEs found in this scope, collected on first access.

    Returns:
        list[exp.CTE]: the CTE nodes
    """
    self._ensure_collected()
    return self._ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables found in this scope, collected on first access.

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    Returns:
        list[exp.Subquery]: the derived table nodes
    """
    self._ensure_collected()
    return self._derived_tables
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Functions" found in this scope, collected on first access.

    Returns:
        list[exp.UDTF]: the UDTF nodes
    """
    self._ensure_collected()
    return self._udtfs
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries found in this scope, collected on first access.

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    Returns:
        list[exp.Select | exp.SetOperation]: the subquery nodes
    """
    self._ensure_collected()
    return self._subqueries
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.SetOperation]: subqueries
@property
def stars(self) -> t.List[exp.Column | exp.Dot]:
    """
    Star expressions found in this scope, collected on first access.

    Returns:
        The star nodes, each either a column or a dot expression.
    """
    self._ensure_collected()
    return self._stars
List of star expressions (columns or dots) in this scope.
@property
def columns(self):
    """
    List of columns in this scope.

    The result is computed once and cached in `_columns`.

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is None:
        self._ensure_collected()
        columns = self._raw_columns

        # External columns of child scopes are candidates too. Derived table
        # scopes only contribute when they are flagged as correlatable.
        external_columns = [
            column
            for scope in itertools.chain(
                self.subquery_scopes,
                self.udtf_scopes,
                (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
            )
            for column in scope.external_columns
        ]

        named_selects = set(self.expression.named_selects)

        self._columns = []
        for column in columns + external_columns:
            # Nearest enclosing node of one of these types decides whether the
            # column is kept.
            ancestor = column.find_ancestor(
                exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
            )
            # Kept when: there is no such ancestor, the column is qualified,
            # the ancestor is a Select, a Table whose `this` is not a function,
            # an Order that sits in a Window or whose column name is not a named
            # select, or a Star where the column is not the "except" arg.
            # NOTE(review): presumably the dropped cases are unqualified names
            # that refer to select aliases rather than source columns - confirm.
            if (
                not ancestor
                or column.table
                or isinstance(ancestor, exp.Select)
                or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                or (
                    isinstance(ancestor, exp.Order)
                    and (
                        isinstance(ancestor.parent, exp.Window)
                        or column.name not in named_selects
                    )
                )
                or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
            ):
                self._columns.append(column)

    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Mapping of nodes and sources that are actually selected from in this scope.

    Every table in a schema is selectable, but it only becomes a selected
    source once it appears in a FROM or JOIN clause.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

    Raises:
        OptimizeError: if the same name is referenced more than once.
    """
    if self._selected_sources is not None:
        return self._selected_sources

    selected = {}
    # `self.references` must be read first: it triggers collection, which
    # populates `_semi_anti_join_tables`.
    for name, node in self.references:
        # The RHS table of SEMI/ANTI joins shouldn't be collected as a
        # selected source
        if name in self._semi_anti_join_tables:
            continue

        if name in selected:
            raise OptimizeError(f"Alias already used: {name}")

        if name in self.sources:
            selected[name] = (node, self.sources[name])

    self._selected_sources = selected
    return selected
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    (name, node) pairs for everything this scope references.

    Tables come first, paired with their alias or name. Derived tables and
    UDTFs follow, paired with their alias; they are unnested unless they
    carry a "pivots" arg. The list is cached after the first access.
    """
    if self._references is not None:
        return self._references

    refs = self._references = []
    refs.extend((table.alias_or_name, table) for table in self.tables)

    for expression in itertools.chain(self.derived_tables, self.udtfs):
        if expression.args.get("pivots"):
            node = expression
        else:
            node = expression.unnest()
        refs.append((expression.alias, node))

    return refs
@property
def external_columns(self):
    """
    Columns that appear to reference sources in outer scopes.

    For a set operation this is the concatenation of both child scopes'
    external columns. The result is cached after the first access.

    Returns:
        list[exp.Column]: Column instances that don't reference
            sources in the current scope.
    """
    if self._external_columns is not None:
        return self._external_columns

    if isinstance(self.expression, exp.SetOperation):
        left, right = self.union_scopes
        external = left.external_columns + right.external_columns
    else:
        external = []
        for column in self.columns:
            # Skip columns whose table is a selected source or a SEMI/ANTI join table.
            if (
                column.table in self.selected_sources
                or column.table in self.semi_or_anti_join_tables
            ):
                continue
            external.append(column)

    self._external_columns = external
    return external
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns in the current scope that have no table qualifier.

    Returns:
        list[exp.Column]: the unqualified columns
    """
    unqualified = []
    for column in self.columns:
        if not column.table:
            unqualified.append(column)
    return unqualified
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Hints that exist in the scope that reference tables

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    # Collect on demand like the other accessors. Previously an uncollected
    # scope (e.g. right after clear_cache()) silently reported no hints.
    self._ensure_collected()
    return self._join_hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Return the columns in the current scope whose table equals `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances that reference `source_name`
    """
    matching = []
    for column in self.columns:
        if column.table == source_name:
            matching.append(column)
    return matching
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference
source_name
@property
def is_subquery(self):
    """Whether this scope's type is `ScopeType.SUBQUERY`."""
    return self.scope_type is ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """Whether this scope's type is `ScopeType.DERIVED_TABLE`."""
    return self.scope_type is ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """Whether this scope's type is `ScopeType.UNION`."""
    return self.scope_type is ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """Whether this scope's type is `ScopeType.CTE` (a common table expression)."""
    return self.scope_type is ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """Whether this scope's type is `ScopeType.ROOT`."""
    return self.scope_type is ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """Whether this scope's type is `ScopeType.UDTF` (User Defined Table Function)."""
    return self.scope_type is ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    Args:
        old_name: current name of the source; a falsy value is looked up as "".
        new_name: name the source is stored under afterwards.
    """
    # An unknown `old_name` registers an empty list under `new_name`.
    columns = self.sources.pop(old_name or "", [])
    self.sources[new_name] = columns
    # Cached values such as `selected_sources` are keyed by source name, so
    # invalidate them just like add_source/remove_source do.
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name` and invalidate the cached attributes."""
    self.sources.update({name: source})
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Drop the source called `name`, if present, and invalidate the cached attributes."""
    if name in self.sources:
        del self.sources[name]
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Traverse the scope tree from this node.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    pending = [self]
    visited = []
    while pending:
        current = pending.pop()
        visited.append(current)
        # Push children in the order CTEs, unions, tables, subqueries.
        for children in (
            current.cte_scopes,
            current.union_scopes,
            current.table_scopes,
            current.subquery_scopes,
        ):
            pending.extend(children)

    # Scopes were recorded parent-first; emit them back to front.
    for scope in reversed(visited):
        yield scope
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count the number of times each scope in this tree is referenced.

    Every entry of every traversed scope's `selected_sources` counts as one
    reference to its source.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(int)

    for scope in self.traverse():
        for selected in scope.selected_sources.values():
            source = selected[1]
            counts[id(source)] += 1

    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty when `expression` is not
        one of `TRAVERSABLES`.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") >>> scopes = traverse_scope(expression) >>> scopes[0].expression.sql(), list(scopes[0].sources) ('SELECT a FROM x', ['x']) >>> scopes[1].expression.sql(), list(scopes[1].sources) ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, i.e. the last scope produced by `traverse_scope`
        (`seq_get` is asked for index -1).
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    A node that starts a child scope is itself yielded, but its subtree is not
    entered. For `exp.Subquery` and `exp.UDTF` nodes the "joins", "laterals"
    and "pivots" args are still walked.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: each visited node
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        # Reset before yielding; the flag is set again below if this node
        # turns out to start a child scope.
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                # NOTE(review): `prune` is not forwarded to these nested walks - confirm intended.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression):
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node) -> bool): callable that receives a node and returns True if the generator should stop traversing this branch of the tree.
Yields:
exp.Expression: each visited node
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalise the types once instead of per node. This also keeps one-shot
    # iterables of types working beyond the first node.
    types = tuple(ensure_collection(expression_types))

    # A separate loop name avoids shadowing the `expression` parameter.
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    matches = find_all_in_scope(expression, expression_types, bfs=bfs)
    return next(matches, None)
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.