sqlglot.optimizer.scope
from __future__ import annotations

import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import ensure_collection, find_new_name, seq_get

logger = logging.getLogger("sqlglot")

# Root expression types that `traverse_scope` will build a scope tree for.
TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)


class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.SetOperation): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources (dict[str, Scope]): Sources from CTEs
        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_columns=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
        cte_sources=None,
        can_be_correlated=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # Lateral and CTE sources are also visible as regular sources of this scope.
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.can_be_correlated = can_be_correlated
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is rebuilt on next access."""
        self._collected = False
        self._raw_columns = None
        self._stars = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(
        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
    ):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            # Inner scopes inherit all CTEs visible here, plus any passed explicitly.
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            # Once a scope can be correlated, every scope nested under it can be too.
            can_be_correlated=self.can_be_correlated
            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
            **kwargs,
        )

    def _collect(self):
        """
        Walk this scope once (depth-first, without entering child scopes) and bucket
        the nodes found into tables, CTEs, subqueries, derived tables, UDTFs, columns,
        stars and join hints.
        """
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._stars = []
        self._join_hints = []

        for node in self.walk(bfs=False):
            if node is self.expression:
                continue

            if isinstance(node, exp.Dot) and node.is_star:
                self._stars.append(node)
            elif isinstance(node, exp.Column):
                if isinstance(node.this, exp.Star):
                    self._stars.append(node)
                else:
                    self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif _is_derived_table(node) and isinstance(
                node.parent, (exp.From, exp.Join, exp.Subquery)
            ):
                self._derived_tables.append(node)
            elif isinstance(node, exp.UNWRAPPED_QUERIES):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` if it has not run since the last `clear_cache`."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True, prune=None):
        """
        Walk the nodes of this scope without descending into child scopes.

        Args:
            bfs (bool): True for breadth-first order, False for depth-first.
            prune ((node) -> bool): optional callable; when it returns True for a node,
                the traversal does not descend below that node.

        Returns:
            Iterator[exp.Expression]: nodes in this scope
        """
        # `prune` used to be accepted but silently replaced with None.
        return walk_in_scope(self.expression, bfs=bfs, prune=prune)

    def find(self, *expression_types, bfs=True):
        """Return the first node in this scope matching any of `expression_types`, or None."""
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """Yield every node in this scope matching any of `expression_types`."""
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.SetOperation]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def stars(self) -> t.List[exp.Column | exp.Dot]:
        """
        List of star expressions (columns or dots) in this scope.
        """
        self._ensure_collected()
        return self._stars

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            # Columns from child scopes that may reference sources of this scope.
            external_columns = [
                column
                for scope in itertools.chain(
                    self.subquery_scopes,
                    self.udtf_scopes,
                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
                )
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
                )
                # Skip unqualified columns under QUALIFY, HAVING, hints, table-valued
                # function calls, or a star's EXCEPT list. Under ORDER BY they are
                # skipped only when they match a projection name outside a window.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if the same alias is referenced twice in this scope.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        (alias, node) pairs for every table, derived table and UDTF referenced in this scope.

        Pivoted derived tables/UDTFs are kept as-is; others are unnested first.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.SetOperation):
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                self._external_columns = [
                    c for c in self.columns if c.table not in self.selected_sources
                ]

        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand; previously this returned [] unless another property
        # had already triggered collection.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivot expressions attached to the nodes referenced by this scope.

        Returns:
            list[exp.Pivot]: pivots
        """
        # Use `is None` so an empty result is cached too.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.can_be_correlated and self.external_columns)

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope.

        A missing `old_name` (None is treated as "") is stored under `new_name` as [].
        Note: unlike add_source/remove_source, this does not clear cached attributes.
        """
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        stack = [self]
        result = []
        while stack:
            scope = stack.pop()
            result.append(scope)
            stack.extend(
                itertools.chain(
                    scope.cte_scopes,
                    scope.union_scopes,
                    scope.table_scopes,
                    scope.subquery_scopes,
                )
            )

        # Reversing the pre-order visit list gives children before parents.
        yield from reversed(result)

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances
    """
    if isinstance(expression, TRAVERSABLES):
        return list(_traverse_scope(Scope(expression)))
    return []


def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope
    """
    # Scopes are produced children-first, so the root is the last one.
    return seq_get(traverse_scope(expression), -1)


def _traverse_scope(scope):
    """Yield the child scopes of `scope` (children first), then `scope` itself where applicable."""
    expression = scope.expression

    if isinstance(expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(expression, exp.SetOperation):
        yield from _traverse_ctes(scope)
        # _traverse_union yields the set-operation scopes itself.
        yield from _traverse_union(scope)
        return
    elif isinstance(expression, exp.Subquery):
        if scope.is_root:
            yield from _traverse_select(scope)
        else:
            yield from _traverse_subqueries(scope)
    elif isinstance(expression, exp.Table):
        yield from _traverse_tables(scope)
    elif isinstance(expression, exp.UDTF):
        yield from _traverse_udtfs(scope)
    elif isinstance(expression, exp.DDL):
        if isinstance(expression.expression, exp.Query):
            yield from _traverse_ctes(scope)
            yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources))
        return
    elif isinstance(expression, exp.DML):
        yield from _traverse_ctes(scope)
        for query in find_all_in_scope(expression, exp.Query):
            # This check ensures we don't yield the CTE/nested queries twice
            if not isinstance(query.parent, (exp.CTE, exp.Subquery)):
                yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
        return
    else:
        logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
        return

    yield scope


def _traverse_select(scope):
    """Yield the CTE, table and subquery child scopes of a SELECT-like scope."""
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    """
    Iteratively traverse a (possibly nested) set operation, left operand first.

    Each completed operand pair is linked into its set-operation scope's
    `union_scopes`, and that scope is yielded after its operands.
    """
    prev_scope = None
    union_scope_stack = [scope]
    expression_stack = [scope.expression.right, scope.expression.left]

    while expression_stack:
        expression = expression_stack.pop()
        union_scope = union_scope_stack[-1]

        new_scope = union_scope.branch(
            expression,
            outer_columns=union_scope.outer_columns,
            scope_type=ScopeType.UNION,
        )

        if isinstance(expression, exp.SetOperation):
            yield from _traverse_ctes(new_scope)

            union_scope_stack.append(new_scope)
            expression_stack.extend([expression.right, expression.left])
            continue

        for scope in _traverse_scope(new_scope):
            yield scope

        if prev_scope:
            # Second operand finished: close the enclosing set-operation scope.
            union_scope_stack.pop()
            union_scope.union_scopes = [prev_scope, scope]
            prev_scope = union_scope

            yield union_scope
        else:
            prev_scope = scope


def _traverse_ctes(scope):
    """Yield the scopes of each CTE in `scope` and register them as sources."""
    sources = {}

    for cte in scope.ctes:
        cte_name = cte.alias

        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
        # thus the recursive scope is the first section of the union.
        with_ = scope.expression.args.get("with")
        if with_ and with_.recursive:
            union = cte.this

            if isinstance(union, exp.SetOperation):
                sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                cte_sources=sources,
                outer_columns=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

        # append the final child_scope yielded
        if child_scope:
            sources[cte_name] = child_scope
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
    scope.cte_sources.update(sources)


def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule.
    """
    return isinstance(expression, exp.Subquery) and bool(
        expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
    )


def _traverse_tables(scope):
    """Yield scopes for derived tables/UDTFs in FROM, JOIN and LATERAL and register all sources."""
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # `expressions` may grow during iteration as nested joins are discovered.
    for expression in expressions:
        if isinstance(expression, exp.Final):
            expression = expression.this
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_columns=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins.
            sources[expression.alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    """Yield scopes for each subquery in `scope`, recording the top one of each."""
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        scope.subquery_scopes.append(top)


def _traverse_udtfs(scope):
    """Yield scopes for derived tables nested inside an UNNEST or LATERAL scope."""
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if _is_derived_table(expression):
            top = None
            for child_scope in _traverse_scope(
                scope.branch(
                    expression,
                    scope_type=ScopeType.SUBQUERY,
                    outer_columns=expression.alias_column_names,
                )
            ):
                yield child_scope
                top = child_scope
                sources[expression.alias] = child_scope

            scope.subquery_scopes.append(top)

    scope.sources.update(sources)


def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: nodes in this scope
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        # Forward `prune` so the caller's pruning also applies here.
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)


def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Build the isinstance() type tuple once instead of once per node.
    types = tuple(ensure_collection(expression_types))
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node


def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
    """
    return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Classifies a scope by how it relates to its parent scope."""

    # Explicit values equal to what auto() assigns (1, 2, 3, ... in declaration order).
    ROOT = 1  # the top-level scope, with no parent
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6
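Enumeration of scope kinds (ROOT, SUBQUERY, DERIVED_TABLE, CTE, UNION, UDTF), describing how a scope relates to its parent scope.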
28class Scope: 29 """ 30 Selection scope. 31 32 Attributes: 33 expression (exp.Select|exp.SetOperation): Root expression of this scope 34 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 35 a Table expression or another Scope instance. For example: 36 SELECT * FROM x {"x": Table(this="x")} 37 SELECT * FROM x AS y {"y": Table(this="x")} 38 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 39 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 40 For example: 41 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 42 The LATERAL VIEW EXPLODE gets x as a source. 43 cte_sources (dict[str, Scope]): Sources from CTES 44 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 45 defines a column list for the alias of this scope, this is that list of columns. 46 For example: 47 SELECT * FROM (SELECT ...) AS y(col1, col2) 48 The inner query would have `["col1", "col2"]` for its `outer_columns` 49 parent (Scope): Parent scope 50 scope_type (ScopeType): Type of this scope, relative to it's parent 51 subquery_scopes (list[Scope]): List of all child scopes for subqueries 52 cte_scopes (list[Scope]): List of all child scopes for CTEs 53 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 54 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 55 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 56 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 57 a list of the left and right child scopes. 
58 """ 59 60 def __init__( 61 self, 62 expression, 63 sources=None, 64 outer_columns=None, 65 parent=None, 66 scope_type=ScopeType.ROOT, 67 lateral_sources=None, 68 cte_sources=None, 69 can_be_correlated=None, 70 ): 71 self.expression = expression 72 self.sources = sources or {} 73 self.lateral_sources = lateral_sources or {} 74 self.cte_sources = cte_sources or {} 75 self.sources.update(self.lateral_sources) 76 self.sources.update(self.cte_sources) 77 self.outer_columns = outer_columns or [] 78 self.parent = parent 79 self.scope_type = scope_type 80 self.subquery_scopes = [] 81 self.derived_table_scopes = [] 82 self.table_scopes = [] 83 self.cte_scopes = [] 84 self.union_scopes = [] 85 self.udtf_scopes = [] 86 self.can_be_correlated = can_be_correlated 87 self.clear_cache() 88 89 def clear_cache(self): 90 self._collected = False 91 self._raw_columns = None 92 self._stars = None 93 self._derived_tables = None 94 self._udtfs = None 95 self._tables = None 96 self._ctes = None 97 self._subqueries = None 98 self._selected_sources = None 99 self._columns = None 100 self._external_columns = None 101 self._join_hints = None 102 self._pivots = None 103 self._references = None 104 105 def branch( 106 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 107 ): 108 """Branch from the current scope to a new, inner scope""" 109 return Scope( 110 expression=expression.unnest(), 111 sources=sources.copy() if sources else None, 112 parent=self, 113 scope_type=scope_type, 114 cte_sources={**self.cte_sources, **(cte_sources or {})}, 115 lateral_sources=lateral_sources.copy() if lateral_sources else None, 116 can_be_correlated=self.can_be_correlated 117 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 118 **kwargs, 119 ) 120 121 def _collect(self): 122 self._tables = [] 123 self._ctes = [] 124 self._subqueries = [] 125 self._derived_tables = [] 126 self._udtfs = [] 127 self._raw_columns = [] 128 self._stars = [] 129 self._join_hints = [] 
130 131 for node in self.walk(bfs=False): 132 if node is self.expression: 133 continue 134 135 if isinstance(node, exp.Dot) and node.is_star: 136 self._stars.append(node) 137 elif isinstance(node, exp.Column): 138 if isinstance(node.this, exp.Star): 139 self._stars.append(node) 140 else: 141 self._raw_columns.append(node) 142 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 143 self._tables.append(node) 144 elif isinstance(node, exp.JoinHint): 145 self._join_hints.append(node) 146 elif isinstance(node, exp.UDTF): 147 self._udtfs.append(node) 148 elif isinstance(node, exp.CTE): 149 self._ctes.append(node) 150 elif _is_derived_table(node) and isinstance( 151 node.parent, (exp.From, exp.Join, exp.Subquery) 152 ): 153 self._derived_tables.append(node) 154 elif isinstance(node, exp.UNWRAPPED_QUERIES): 155 self._subqueries.append(node) 156 157 self._collected = True 158 159 def _ensure_collected(self): 160 if not self._collected: 161 self._collect() 162 163 def walk(self, bfs=True, prune=None): 164 return walk_in_scope(self.expression, bfs=bfs, prune=None) 165 166 def find(self, *expression_types, bfs=True): 167 return find_in_scope(self.expression, expression_types, bfs=bfs) 168 169 def find_all(self, *expression_types, bfs=True): 170 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 171 172 def replace(self, old, new): 173 """ 174 Replace `old` with `new`. 175 176 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 177 178 Args: 179 old (exp.Expression): old node 180 new (exp.Expression): new node 181 """ 182 old.replace(new) 183 self.clear_cache() 184 185 @property 186 def tables(self): 187 """ 188 List of tables in this scope. 189 190 Returns: 191 list[exp.Table]: tables 192 """ 193 self._ensure_collected() 194 return self._tables 195 196 @property 197 def ctes(self): 198 """ 199 List of CTEs in this scope. 
200 201 Returns: 202 list[exp.CTE]: ctes 203 """ 204 self._ensure_collected() 205 return self._ctes 206 207 @property 208 def derived_tables(self): 209 """ 210 List of derived tables in this scope. 211 212 For example: 213 SELECT * FROM (SELECT ...) <- that's a derived table 214 215 Returns: 216 list[exp.Subquery]: derived tables 217 """ 218 self._ensure_collected() 219 return self._derived_tables 220 221 @property 222 def udtfs(self): 223 """ 224 List of "User Defined Tabular Functions" in this scope. 225 226 Returns: 227 list[exp.UDTF]: UDTFs 228 """ 229 self._ensure_collected() 230 return self._udtfs 231 232 @property 233 def subqueries(self): 234 """ 235 List of subqueries in this scope. 236 237 For example: 238 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 239 240 Returns: 241 list[exp.Select | exp.SetOperation]: subqueries 242 """ 243 self._ensure_collected() 244 return self._subqueries 245 246 @property 247 def stars(self) -> t.List[exp.Column | exp.Dot]: 248 """ 249 List of star expressions (columns or dots) in this scope. 250 """ 251 self._ensure_collected() 252 return self._stars 253 254 @property 255 def columns(self): 256 """ 257 List of columns in this scope. 258 259 Returns: 260 list[exp.Column]: Column instances in this scope, plus any 261 Columns that reference this scope from correlated subqueries. 
262 """ 263 if self._columns is None: 264 self._ensure_collected() 265 columns = self._raw_columns 266 267 external_columns = [ 268 column 269 for scope in itertools.chain( 270 self.subquery_scopes, 271 self.udtf_scopes, 272 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 273 ) 274 for column in scope.external_columns 275 ] 276 277 named_selects = set(self.expression.named_selects) 278 279 self._columns = [] 280 for column in columns + external_columns: 281 ancestor = column.find_ancestor( 282 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star 283 ) 284 if ( 285 not ancestor 286 or column.table 287 or isinstance(ancestor, exp.Select) 288 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 289 or ( 290 isinstance(ancestor, exp.Order) 291 and ( 292 isinstance(ancestor.parent, exp.Window) 293 or column.name not in named_selects 294 ) 295 ) 296 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") 297 ): 298 self._columns.append(column) 299 300 return self._columns 301 302 @property 303 def selected_sources(self): 304 """ 305 Mapping of nodes and sources that are actually selected from in this scope. 306 307 That is, all tables in a schema are selectable at any point. But a 308 table only becomes a selected source if it's included in a FROM or JOIN clause. 
309 310 Returns: 311 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 312 """ 313 if self._selected_sources is None: 314 result = {} 315 316 for name, node in self.references: 317 if name in result: 318 raise OptimizeError(f"Alias already used: {name}") 319 if name in self.sources: 320 result[name] = (node, self.sources[name]) 321 322 self._selected_sources = result 323 return self._selected_sources 324 325 @property 326 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 327 if self._references is None: 328 self._references = [] 329 330 for table in self.tables: 331 self._references.append((table.alias_or_name, table)) 332 for expression in itertools.chain(self.derived_tables, self.udtfs): 333 self._references.append( 334 ( 335 expression.alias, 336 expression if expression.args.get("pivots") else expression.unnest(), 337 ) 338 ) 339 340 return self._references 341 342 @property 343 def external_columns(self): 344 """ 345 Columns that appear to reference sources in outer scopes. 346 347 Returns: 348 list[exp.Column]: Column instances that don't reference 349 sources in the current scope. 350 """ 351 if self._external_columns is None: 352 if isinstance(self.expression, exp.SetOperation): 353 left, right = self.union_scopes 354 self._external_columns = left.external_columns + right.external_columns 355 else: 356 self._external_columns = [ 357 c for c in self.columns if c.table not in self.selected_sources 358 ] 359 360 return self._external_columns 361 362 @property 363 def unqualified_columns(self): 364 """ 365 Unqualified columns in the current scope. 
366 367 Returns: 368 list[exp.Column]: Unqualified columns 369 """ 370 return [c for c in self.columns if not c.table] 371 372 @property 373 def join_hints(self): 374 """ 375 Hints that exist in the scope that reference tables 376 377 Returns: 378 list[exp.JoinHint]: Join hints that are referenced within the scope 379 """ 380 if self._join_hints is None: 381 return [] 382 return self._join_hints 383 384 @property 385 def pivots(self): 386 if not self._pivots: 387 self._pivots = [ 388 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 389 ] 390 391 return self._pivots 392 393 def source_columns(self, source_name): 394 """ 395 Get all columns in the current scope for a particular source. 396 397 Args: 398 source_name (str): Name of the source 399 Returns: 400 list[exp.Column]: Column instances that reference `source_name` 401 """ 402 return [column for column in self.columns if column.table == source_name] 403 404 @property 405 def is_subquery(self): 406 """Determine if this scope is a subquery""" 407 return self.scope_type == ScopeType.SUBQUERY 408 409 @property 410 def is_derived_table(self): 411 """Determine if this scope is a derived table""" 412 return self.scope_type == ScopeType.DERIVED_TABLE 413 414 @property 415 def is_union(self): 416 """Determine if this scope is a union""" 417 return self.scope_type == ScopeType.UNION 418 419 @property 420 def is_cte(self): 421 """Determine if this scope is a common table expression""" 422 return self.scope_type == ScopeType.CTE 423 424 @property 425 def is_root(self): 426 """Determine if this is the root scope""" 427 return self.scope_type == ScopeType.ROOT 428 429 @property 430 def is_udtf(self): 431 """Determine if this scope is a UDTF (User Defined Table Function)""" 432 return self.scope_type == ScopeType.UDTF 433 434 @property 435 def is_correlated_subquery(self): 436 """Determine if this scope is a correlated subquery""" 437 return bool(self.can_be_correlated and self.external_columns) 
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    The entry stored under `old_name` (or under `""` when `old_name` is falsy) is moved to
    `new_name`. If there is no such entry, `new_name` is mapped to an empty list, which is
    kept for backward compatibility. Cached properties are invalidated, just like
    `add_source` and `remove_source` do, because `selected_sources` and friends are
    derived from `self.sources`.
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    self.clear_cache()


def add_source(self, name, source):
    """Register `source` under `name` and invalidate the cached properties."""
    self.sources[name] = source
    self.clear_cache()


def remove_source(self, name):
    """Drop the source registered under `name` (if any) and invalidate the cached properties."""
    self.sources.pop(name, None)
    self.clear_cache()


def __repr__(self):
    return f"Scope<{self.expression.sql()}>"


def traverse(self):
    """
    Traverse the scope tree from this node.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    pending = [self]
    visited = []
    while pending:
        current = pending.pop()
        visited.append(current)
        # Children are pushed in this order; reversing the visit list at the end turns the
        # pre-order walk into a post-order one (children before their parent).
        pending.extend(current.cte_scopes)
        pending.extend(current.union_scopes)
        pending.extend(current.table_scopes)
        pending.extend(current.subquery_scopes)

    yield from reversed(visited)


def ref_count(self):
    """
    Count the number of times each scope in this tree is referenced.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(lambda: 0)

    for scope in self.traverse():
        for _node, source in scope.selected_sources.values():
            counts[id(source)] += 1

    return counts
Selection scope.
Attributes:
- expression (exp.Select|exp.SetOperation): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: `SELECT * FROM x` → `{"x": Table(this="x")}`; `SELECT * FROM x AS y` → `{"y": Table(this="x")}`; `SELECT * FROM (SELECT ...) AS y` → `{"y": Scope(...)}`.
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c`, the LATERAL VIEW EXPLODE gets `x` as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_columns`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_columns=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
    can_be_correlated=None,
):
    """
    Create a scope for `expression`.

    Args:
        expression: root expression of this scope.
        sources (dict[str, exp.Table|Scope] | None): initial mapping of source name to a
            Table expression or child Scope. It is copied, so merging the lateral and CTE
            sources below never mutates a dict owned by the caller.
        outer_columns (list[str] | None): column names declared by an outer alias,
            e.g. `AS y(col1, col2)`.
        parent (Scope | None): the enclosing scope.
        scope_type (ScopeType): type of this scope relative to its parent.
        lateral_sources (dict | None): sources coming from laterals; also merged into
            `sources`.
        cte_sources (dict | None): sources coming from CTEs; also merged into `sources`
            after the lateral ones, so a CTE wins on a name clash.
        can_be_correlated: flag read by `is_correlated_subquery` and propagated by `branch`.
    """
    self.expression = expression
    # Copy: the updates below must not leak into a dict passed in by the caller.
    self.sources = dict(sources) if sources else {}
    self.lateral_sources = lateral_sources or {}
    self.cte_sources = cte_sources or {}
    self.sources.update(self.lateral_sources)
    self.sources.update(self.cte_sources)
    self.outer_columns = outer_columns or []
    self.parent = parent
    self.scope_type = scope_type
    # Child scope lists (see the class docstring); this constructor starts them empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []
    self.can_be_correlated = can_be_correlated
    self.clear_cache()
def clear_cache(self):
    """Forget every lazily computed value so it is rebuilt on its next access."""
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_stars",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    ):
        setattr(self, cached_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Create a child scope nested inside this one.

    Args:
        expression: node for the child scope; it is unnested before use.
        scope_type (ScopeType): type of the child relative to this scope.
        sources: initial sources of the child (copied when given).
        cte_sources: extra CTE sources, layered on top of this scope's CTE sources.
        lateral_sources: lateral sources of the child (copied when given).
        **kwargs: any other Scope constructor arguments, such as `outer_columns`.

    Returns:
        Scope: the new child, whose parent is this scope. It can be correlated when this
        scope can be, or when it is a subquery or UDTF scope.
    """
    child_ctes = {**self.cte_sources}
    child_ctes.update(cte_sources or {})

    correlatable = self.can_be_correlated or scope_type in (
        ScopeType.SUBQUERY,
        ScopeType.UDTF,
    )

    return Scope(
        expression=expression.unnest(),
        sources=sources.copy() if sources else None,
        parent=self,
        scope_type=scope_type,
        cte_sources=child_ctes,
        lateral_sources=lateral_sources.copy() if lateral_sources else None,
        can_be_correlated=correlatable,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap node `old` for `new` in the tree and keep this scope consistent.

    Prefer this over calling `exp.Expression.replace` directly, because it also resets
    the cached properties of this `Scope`, which may still point at `old`.

    Args:
        old (exp.Expression): node to remove
        new (exp.Expression): node to put in its place
    """
    old.replace(new)
    # Anything collected before the swap may reference the old node.
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Table nodes found in this scope, gathered on first access.

    Returns:
        list[exp.Table]: tables
    """
    self._ensure_collected()
    return self._tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    CTE nodes found in this scope, gathered on first access.

    Returns:
        list[exp.CTE]: ctes
    """
    self._ensure_collected()
    return self._ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables of this scope, gathered on first access.

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    Returns:
        list[exp.Subquery]: derived tables
    """
    self._ensure_collected()
    return self._derived_tables
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Function" nodes of this scope, gathered on first access.

    Returns:
        list[exp.UDTF]: UDTFs
    """
    self._ensure_collected()
    return self._udtfs
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries of this scope, gathered on first access.

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    Returns:
        list[exp.Select | exp.SetOperation]: subqueries
    """
    self._ensure_collected()
    return self._subqueries
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.SetOperation]: subqueries
@property
def stars(self) -> t.List[exp.Column | exp.Dot]:
    """
    Star expressions (`*` columns or dotted `x.*` forms) in this scope, gathered on first access.
    """
    self._ensure_collected()
    return self._stars
List of star expressions (columns or dots) in this scope.
@property
def columns(self):
    """
    List of columns in this scope.

    This scope's own column references are combined with the external columns of its
    subquery scopes, UDTF scopes and correlatable derived-table scopes. Each candidate is
    kept when it is qualified, or when its nearest ancestor among SELECT / QUALIFY / ORDER BY
    / HAVING / hint / table / star is:
      - missing, or a SELECT;
      - a table whose `this` is not a function;
      - an ORDER BY inside a window, or an ORDER BY whose column name is not a projection
        name of this scope;
      - a star, unless the column's arg key is `except`.
    All other unqualified candidates (for example those under QUALIFY, HAVING or a hint) are
    dropped. The result is cached.

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is None:
        self._ensure_collected()
        own_columns = self._raw_columns

        child_scopes = itertools.chain(
            self.subquery_scopes,
            self.udtf_scopes,
            (child for child in self.derived_table_scopes if child.can_be_correlated),
        )
        inherited = [col for child in child_scopes for col in child.external_columns]

        projection_names = set(self.expression.named_selects)

        def _is_scope_column(column):
            # Conditions are checked in the same order, with the same short-circuits.
            ancestor = column.find_ancestor(
                exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
            )
            if not ancestor or column.table:
                return True
            if isinstance(ancestor, exp.Select):
                return True
            if isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func):
                return True
            if isinstance(ancestor, exp.Order) and (
                isinstance(ancestor.parent, exp.Window) or column.name not in projection_names
            ):
                return True
            return isinstance(ancestor, exp.Star) and column.arg_key != "except"

        self._columns = []
        for candidate in own_columns + inherited:
            if _is_scope_column(candidate):
                self._columns.append(candidate)

    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Mapping of nodes and sources that are actually selected from in this scope.

    Every table in a schema is selectable, but a table only becomes a selected source
    once this scope references it (FROM / JOIN and similar). Only referenced names that
    are also present in `self.sources` end up in the mapping. The result is cached.

    Raises:
        OptimizeError: if a name that is already selected is referenced again.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
    """
    if self._selected_sources is not None:
        return self._selected_sources

    picked = {}
    for alias, node in self.references:
        if alias in picked:
            raise OptimizeError(f"Alias already used: {alias}")
        if alias in self.sources:
            picked[alias] = (node, self.sources[alias])

    self._selected_sources = picked
    return picked
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    `(name, node)` pairs for every table, derived table and UDTF in this scope.

    Tables are listed first under their alias-or-name. Derived tables and UDTFs follow under
    their alias: a node that carries pivots is kept as-is, otherwise its unnested form is
    used. The list is cached.
    """
    if self._references is None:
        collected = [(table.alias_or_name, table) for table in self.tables]
        for source in itertools.chain(self.derived_tables, self.udtfs):
            name = source.alias
            target = source if source.args.get("pivots") else source.unnest()
            collected.append((name, target))
        self._references = collected

    return self._references
@property
def external_columns(self):
    """
    Columns that appear to reference sources in outer scopes.

    For a set operation, this is the concatenation of the external columns of its left and
    right child scopes. Otherwise, it is the columns of this scope whose table qualifier
    is not one of this scope's selected sources. The result is cached.

    Returns:
        list[exp.Column]: Column instances that don't reference
            sources in the current scope.
    """
    if self._external_columns is None:
        if isinstance(self.expression, exp.SetOperation):
            left, right = self.union_scopes
            found = left.external_columns + right.external_columns
        else:
            found = []
            for column in self.columns:
                # selected_sources is read per column (as before), so it is never evaluated
                # for a scope without columns.
                if column.table not in self.selected_sources:
                    found.append(column)
        self._external_columns = found

    return self._external_columns
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns of this scope that carry no table qualifier.

    Returns:
        list[exp.Column]: Unqualified columns
    """
    return list(filter(lambda column: not column.table, self.columns))
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Hints that exist in the scope that reference tables.

    Collected lazily like the other cached properties, so hints are not silently
    dropped when the cache is empty (for example right after `clear_cache()`).

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    self._ensure_collected()
    # Keep the "never None" contract even if collection left the attribute unset.
    if self._join_hints is None:
        return []
    return self._join_hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Columns of this scope that are qualified with `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances that reference `source_name`
    """
    matches = []
    for column in self.columns:
        if column.table == source_name:
            matches.append(column)
    return matches
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self):
    """True when this scope was created for a subquery."""
    return self.scope_type is ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """True when this scope was created for a derived table."""
    return self.scope_type is ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """True when this scope is one side of a union."""
    return self.scope_type is ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """True when this scope was created for a common table expression."""
    return self.scope_type is ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """True when this is the root scope of the tree."""
    return self.scope_type is ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """True when this scope was created for a UDTF (User Defined Table Function)."""
    return self.scope_type is ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    The entry stored under `old_name` (or under `""` when `old_name` is falsy) is moved to
    `new_name`. If there is no such entry, `new_name` is mapped to an empty list, which is
    kept for backward compatibility. Cached properties are invalidated, just like
    `add_source` and `remove_source` do, because `selected_sources` and friends are
    derived from `self.sources`.
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name`, then drop the cached properties."""
    self.sources.update({name: source})
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Forget the source registered as `name` (a missing name is ignored), then drop the cached properties."""
    if name in self.sources:
        del self.sources[name]
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    Yields:
        Scope: every scope in the tree in depth-first-search post-order
        (children before their parent).
    """
    pending = [self]
    visited = []
    while pending:
        current = pending.pop()
        visited.append(current)
        # Reversing the resulting pre-order visit list at the end yields post-order.
        pending.extend(current.cte_scopes)
        pending.extend(current.union_scopes)
        pending.extend(current.table_scopes)
        pending.extend(current.subquery_scopes)

    for scope in visited[::-1]:
        yield scope
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Tally how often each scope (or other source object) is selected from across this tree.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(int)
    for scope in self.traverse():
        for _node, source in scope.selected_sources.values():
            counts[id(source)] += 1
    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty when `expression` is not one of
        the traversable node types.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []
    root = Scope(expression)
    return list(_traverse_scope(root))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree and hand back its root.

    The root is the last scope produced by `traverse_scope`, which lists scopes
    children-first.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, or None when no scopes were produced.
    """
    all_scopes = traverse_scope(expression)
    return seq_get(all_scopes, -1)
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): the root node to walk from.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((exp.Expression) -> bool): callable that returns True if the generator
            should stop traversing below the given node. It is also applied while walking
            the joins, laterals and pivots attached to child subqueries / UDTFs.

    Yields:
        exp.Expression: the nodes that belong to this scope.
    """
    # State shared with the prune callback below: setting it to True makes the underlying
    # walk skip the children of the node that was just yielded.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit
                # them, honouring the caller's prune callback as for every other node.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the root node to walk from.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((exp.Expression) -> bool): callable that is given a single node and returns True if the generator should stop traversing below it.
Yields:
exp.Expression: the nodes that belong to this scope
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the root node to search from.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalise the requested types once, instead of once per visited node.
    wanted_types = tuple(ensure_collection(expression_types))
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted_types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the root node to search from.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Return the first node of this scope that is an instance of any of `expression_types`.

    Subscopes are NOT searched.

    Args:
        expression (exp.Expression): the root node to search from.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the matching node, or None when nothing in this scope matches.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the root node to search from.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.