sqlglot.optimizer.scope
1from __future__ import annotations 2 3import itertools 4import logging 5import typing as t 6from collections import defaultdict 7from enum import Enum, auto 8 9from sqlglot import exp 10from sqlglot.errors import OptimizeError 11from sqlglot.helper import ensure_collection, find_new_name, seq_get 12 13logger = logging.getLogger("sqlglot") 14 15TRAVERSABLES = (exp.Query, exp.DDL, exp.DML) 16 17 18class ScopeType(Enum): 19 ROOT = auto() 20 SUBQUERY = auto() 21 DERIVED_TABLE = auto() 22 CTE = auto() 23 UNION = auto() 24 UDTF = auto() 25 26 27class Scope: 28 """ 29 Selection scope. 30 31 Attributes: 32 expression (exp.Select|exp.SetOperation): Root expression of this scope 33 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 34 a Table expression or another Scope instance. For example: 35 SELECT * FROM x {"x": Table(this="x")} 36 SELECT * FROM x AS y {"y": Table(this="x")} 37 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 38 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 39 For example: 40 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 41 The LATERAL VIEW EXPLODE gets x as a source. 42 cte_sources (dict[str, Scope]): Sources from CTES 43 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 44 defines a column list for the alias of this scope, this is that list of columns. 45 For example: 46 SELECT * FROM (SELECT ...) 
AS y(col1, col2) 47 The inner query would have `["col1", "col2"]` for its `outer_columns` 48 parent (Scope): Parent scope 49 scope_type (ScopeType): Type of this scope, relative to it's parent 50 subquery_scopes (list[Scope]): List of all child scopes for subqueries 51 cte_scopes (list[Scope]): List of all child scopes for CTEs 52 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 53 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 54 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 55 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 56 a list of the left and right child scopes. 57 """ 58 59 def __init__( 60 self, 61 expression, 62 sources=None, 63 outer_columns=None, 64 parent=None, 65 scope_type=ScopeType.ROOT, 66 lateral_sources=None, 67 cte_sources=None, 68 can_be_correlated=None, 69 ): 70 self.expression = expression 71 self.sources = sources or {} 72 self.lateral_sources = lateral_sources or {} 73 self.cte_sources = cte_sources or {} 74 self.sources.update(self.lateral_sources) 75 self.sources.update(self.cte_sources) 76 self.outer_columns = outer_columns or [] 77 self.parent = parent 78 self.scope_type = scope_type 79 self.subquery_scopes = [] 80 self.derived_table_scopes = [] 81 self.table_scopes = [] 82 self.cte_scopes = [] 83 self.union_scopes = [] 84 self.udtf_scopes = [] 85 self.can_be_correlated = can_be_correlated 86 self.clear_cache() 87 88 def clear_cache(self): 89 self._collected = False 90 self._raw_columns = None 91 self._stars = None 92 self._derived_tables = None 93 self._udtfs = None 94 self._tables = None 95 self._ctes = None 96 self._subqueries = None 97 self._selected_sources = None 98 self._columns = None 99 self._external_columns = None 100 self._join_hints = None 101 self._pivots = None 102 self._references = None 103 self._semi_anti_join_tables = None 104 105 def branch( 106 self, 
expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 107 ): 108 """Branch from the current scope to a new, inner scope""" 109 return Scope( 110 expression=expression.unnest(), 111 sources=sources.copy() if sources else None, 112 parent=self, 113 scope_type=scope_type, 114 cte_sources={**self.cte_sources, **(cte_sources or {})}, 115 lateral_sources=lateral_sources.copy() if lateral_sources else None, 116 can_be_correlated=self.can_be_correlated 117 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 118 **kwargs, 119 ) 120 121 def _collect(self): 122 self._tables = [] 123 self._ctes = [] 124 self._subqueries = [] 125 self._derived_tables = [] 126 self._udtfs = [] 127 self._raw_columns = [] 128 self._stars = [] 129 self._join_hints = [] 130 self._semi_anti_join_tables = set() 131 132 for node in self.walk(bfs=False): 133 if node is self.expression: 134 continue 135 136 if isinstance(node, exp.Dot) and node.is_star: 137 self._stars.append(node) 138 elif isinstance(node, exp.Column): 139 if isinstance(node.this, exp.Star): 140 self._stars.append(node) 141 else: 142 self._raw_columns.append(node) 143 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 144 parent = node.parent 145 if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join: 146 self._semi_anti_join_tables.add(node.alias_or_name) 147 148 self._tables.append(node) 149 elif isinstance(node, exp.JoinHint): 150 self._join_hints.append(node) 151 elif isinstance(node, exp.UDTF): 152 self._udtfs.append(node) 153 elif isinstance(node, exp.CTE): 154 self._ctes.append(node) 155 elif _is_derived_table(node) and _is_from_or_join(node): 156 self._derived_tables.append(node) 157 elif isinstance(node, exp.UNWRAPPED_QUERIES): 158 self._subqueries.append(node) 159 160 self._collected = True 161 162 def _ensure_collected(self): 163 if not self._collected: 164 self._collect() 165 166 def walk(self, bfs=True, prune=None): 167 return 
walk_in_scope(self.expression, bfs=bfs, prune=None) 168 169 def find(self, *expression_types, bfs=True): 170 return find_in_scope(self.expression, expression_types, bfs=bfs) 171 172 def find_all(self, *expression_types, bfs=True): 173 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 174 175 def replace(self, old, new): 176 """ 177 Replace `old` with `new`. 178 179 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 180 181 Args: 182 old (exp.Expression): old node 183 new (exp.Expression): new node 184 """ 185 old.replace(new) 186 self.clear_cache() 187 188 @property 189 def tables(self): 190 """ 191 List of tables in this scope. 192 193 Returns: 194 list[exp.Table]: tables 195 """ 196 self._ensure_collected() 197 return self._tables 198 199 @property 200 def ctes(self): 201 """ 202 List of CTEs in this scope. 203 204 Returns: 205 list[exp.CTE]: ctes 206 """ 207 self._ensure_collected() 208 return self._ctes 209 210 @property 211 def derived_tables(self): 212 """ 213 List of derived tables in this scope. 214 215 For example: 216 SELECT * FROM (SELECT ...) <- that's a derived table 217 218 Returns: 219 list[exp.Subquery]: derived tables 220 """ 221 self._ensure_collected() 222 return self._derived_tables 223 224 @property 225 def udtfs(self): 226 """ 227 List of "User Defined Tabular Functions" in this scope. 228 229 Returns: 230 list[exp.UDTF]: UDTFs 231 """ 232 self._ensure_collected() 233 return self._udtfs 234 235 @property 236 def subqueries(self): 237 """ 238 List of subqueries in this scope. 239 240 For example: 241 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 242 243 Returns: 244 list[exp.Select | exp.SetOperation]: subqueries 245 """ 246 self._ensure_collected() 247 return self._subqueries 248 249 @property 250 def stars(self) -> t.List[exp.Column | exp.Dot]: 251 """ 252 List of star expressions (columns or dots) in this scope. 
253 """ 254 self._ensure_collected() 255 return self._stars 256 257 @property 258 def columns(self): 259 """ 260 List of columns in this scope. 261 262 Returns: 263 list[exp.Column]: Column instances in this scope, plus any 264 Columns that reference this scope from correlated subqueries. 265 """ 266 if self._columns is None: 267 self._ensure_collected() 268 columns = self._raw_columns 269 270 external_columns = [ 271 column 272 for scope in itertools.chain( 273 self.subquery_scopes, 274 self.udtf_scopes, 275 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 276 ) 277 for column in scope.external_columns 278 ] 279 280 named_selects = set(self.expression.named_selects) 281 282 self._columns = [] 283 for column in columns + external_columns: 284 ancestor = column.find_ancestor( 285 exp.Select, 286 exp.Qualify, 287 exp.Order, 288 exp.Having, 289 exp.Hint, 290 exp.Table, 291 exp.Star, 292 exp.Distinct, 293 ) 294 if ( 295 not ancestor 296 or column.table 297 or isinstance(ancestor, exp.Select) 298 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 299 or ( 300 isinstance(ancestor, (exp.Order, exp.Distinct)) 301 and ( 302 isinstance(ancestor.parent, (exp.Window, exp.WithinGroup)) 303 or column.name not in named_selects 304 ) 305 ) 306 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") 307 ): 308 self._columns.append(column) 309 310 return self._columns 311 312 @property 313 def selected_sources(self): 314 """ 315 Mapping of nodes and sources that are actually selected from in this scope. 316 317 That is, all tables in a schema are selectable at any point. But a 318 table only becomes a selected source if it's included in a FROM or JOIN clause. 
319 320 Returns: 321 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 322 """ 323 if self._selected_sources is None: 324 result = {} 325 326 for name, node in self.references: 327 if name in self._semi_anti_join_tables: 328 # The RHS table of SEMI/ANTI joins shouldn't be collected as a 329 # selected source 330 continue 331 332 if name in result: 333 raise OptimizeError(f"Alias already used: {name}") 334 if name in self.sources: 335 result[name] = (node, self.sources[name]) 336 337 self._selected_sources = result 338 return self._selected_sources 339 340 @property 341 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 342 if self._references is None: 343 self._references = [] 344 345 for table in self.tables: 346 self._references.append((table.alias_or_name, table)) 347 for expression in itertools.chain(self.derived_tables, self.udtfs): 348 self._references.append( 349 ( 350 expression.alias, 351 expression if expression.args.get("pivots") else expression.unnest(), 352 ) 353 ) 354 355 return self._references 356 357 @property 358 def external_columns(self): 359 """ 360 Columns that appear to reference sources in outer scopes. 361 362 Returns: 363 list[exp.Column]: Column instances that don't reference 364 sources in the current scope. 365 """ 366 if self._external_columns is None: 367 if isinstance(self.expression, exp.SetOperation): 368 left, right = self.union_scopes 369 self._external_columns = left.external_columns + right.external_columns 370 else: 371 self._external_columns = [ 372 c 373 for c in self.columns 374 if c.table not in self.selected_sources 375 and c.table not in self.semi_or_anti_join_tables 376 ] 377 378 return self._external_columns 379 380 @property 381 def unqualified_columns(self): 382 """ 383 Unqualified columns in the current scope. 
384 385 Returns: 386 list[exp.Column]: Unqualified columns 387 """ 388 return [c for c in self.columns if not c.table] 389 390 @property 391 def join_hints(self): 392 """ 393 Hints that exist in the scope that reference tables 394 395 Returns: 396 list[exp.JoinHint]: Join hints that are referenced within the scope 397 """ 398 if self._join_hints is None: 399 return [] 400 return self._join_hints 401 402 @property 403 def pivots(self): 404 if not self._pivots: 405 self._pivots = [ 406 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 407 ] 408 409 return self._pivots 410 411 @property 412 def semi_or_anti_join_tables(self): 413 return self._semi_anti_join_tables or set() 414 415 def source_columns(self, source_name): 416 """ 417 Get all columns in the current scope for a particular source. 418 419 Args: 420 source_name (str): Name of the source 421 Returns: 422 list[exp.Column]: Column instances that reference `source_name` 423 """ 424 return [column for column in self.columns if column.table == source_name] 425 426 @property 427 def is_subquery(self): 428 """Determine if this scope is a subquery""" 429 return self.scope_type == ScopeType.SUBQUERY 430 431 @property 432 def is_derived_table(self): 433 """Determine if this scope is a derived table""" 434 return self.scope_type == ScopeType.DERIVED_TABLE 435 436 @property 437 def is_union(self): 438 """Determine if this scope is a union""" 439 return self.scope_type == ScopeType.UNION 440 441 @property 442 def is_cte(self): 443 """Determine if this scope is a common table expression""" 444 return self.scope_type == ScopeType.CTE 445 446 @property 447 def is_root(self): 448 """Determine if this is the root scope""" 449 return self.scope_type == ScopeType.ROOT 450 451 @property 452 def is_udtf(self): 453 """Determine if this scope is a UDTF (User Defined Table Function)""" 454 return self.scope_type == ScopeType.UDTF 455 456 @property 457 def is_correlated_subquery(self): 458 """Determine if 
this scope is a correlated subquery""" 459 return bool(self.can_be_correlated and self.external_columns) 460 461 def rename_source(self, old_name, new_name): 462 """Rename a source in this scope""" 463 columns = self.sources.pop(old_name or "", []) 464 self.sources[new_name] = columns 465 466 def add_source(self, name, source): 467 """Add a source to this scope""" 468 self.sources[name] = source 469 self.clear_cache() 470 471 def remove_source(self, name): 472 """Remove a source from this scope""" 473 self.sources.pop(name, None) 474 self.clear_cache() 475 476 def __repr__(self): 477 return f"Scope<{self.expression.sql()}>" 478 479 def traverse(self): 480 """ 481 Traverse the scope tree from this node. 482 483 Yields: 484 Scope: scope instances in depth-first-search post-order 485 """ 486 stack = [self] 487 result = [] 488 while stack: 489 scope = stack.pop() 490 result.append(scope) 491 stack.extend( 492 itertools.chain( 493 scope.cte_scopes, 494 scope.union_scopes, 495 scope.table_scopes, 496 scope.subquery_scopes, 497 ) 498 ) 499 500 yield from reversed(result) 501 502 def ref_count(self): 503 """ 504 Count the number of times each scope in this tree is referenced. 505 506 Returns: 507 dict[int, int]: Mapping of Scope instance ID to reference count 508 """ 509 scope_ref_count = defaultdict(lambda: 0) 510 511 for scope in self.traverse(): 512 for _, source in scope.selected_sources.values(): 513 scope_ref_count[id(source)] += 1 514 515 return scope_ref_count 516 517 518def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 519 """ 520 Traverse an expression by its "scopes". 521 522 "Scope" represents the current context of a Select statement. 523 524 This is helpful for optimizing queries, where we need more information than 525 the expression tree itself. For example, we might care about the source 526 names within a subquery. Returns a list because a generator could result in 527 incomplete properties which is confusing. 
528 529 Examples: 530 >>> import sqlglot 531 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 532 >>> scopes = traverse_scope(expression) 533 >>> scopes[0].expression.sql(), list(scopes[0].sources) 534 ('SELECT a FROM x', ['x']) 535 >>> scopes[1].expression.sql(), list(scopes[1].sources) 536 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 537 538 Args: 539 expression: Expression to traverse 540 541 Returns: 542 A list of the created scope instances 543 """ 544 if isinstance(expression, TRAVERSABLES): 545 return list(_traverse_scope(Scope(expression))) 546 return [] 547 548 549def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 550 """ 551 Build a scope tree. 552 553 Args: 554 expression: Expression to build the scope tree for. 555 556 Returns: 557 The root scope 558 """ 559 return seq_get(traverse_scope(expression), -1) 560 561 562def _traverse_scope(scope): 563 expression = scope.expression 564 565 if isinstance(expression, exp.Select): 566 yield from _traverse_select(scope) 567 elif isinstance(expression, exp.SetOperation): 568 yield from _traverse_ctes(scope) 569 yield from _traverse_union(scope) 570 return 571 elif isinstance(expression, exp.Subquery): 572 if scope.is_root: 573 yield from _traverse_select(scope) 574 else: 575 yield from _traverse_subqueries(scope) 576 elif isinstance(expression, exp.Table): 577 yield from _traverse_tables(scope) 578 elif isinstance(expression, exp.UDTF): 579 yield from _traverse_udtfs(scope) 580 elif isinstance(expression, exp.DDL): 581 if isinstance(expression.expression, exp.Query): 582 yield from _traverse_ctes(scope) 583 yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources)) 584 return 585 elif isinstance(expression, exp.DML): 586 yield from _traverse_ctes(scope) 587 for query in find_all_in_scope(expression, exp.Query): 588 # This check ensures we don't yield the CTE/nested queries twice 589 if not isinstance(query.parent, (exp.CTE, exp.Subquery)): 
590 yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources)) 591 return 592 else: 593 logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression)) 594 return 595 596 yield scope 597 598 599def _traverse_select(scope): 600 yield from _traverse_ctes(scope) 601 yield from _traverse_tables(scope) 602 yield from _traverse_subqueries(scope) 603 604 605def _traverse_union(scope): 606 prev_scope = None 607 union_scope_stack = [scope] 608 expression_stack = [scope.expression.right, scope.expression.left] 609 610 while expression_stack: 611 expression = expression_stack.pop() 612 union_scope = union_scope_stack[-1] 613 614 new_scope = union_scope.branch( 615 expression, 616 outer_columns=union_scope.outer_columns, 617 scope_type=ScopeType.UNION, 618 ) 619 620 if isinstance(expression, exp.SetOperation): 621 yield from _traverse_ctes(new_scope) 622 623 union_scope_stack.append(new_scope) 624 expression_stack.extend([expression.right, expression.left]) 625 continue 626 627 for scope in _traverse_scope(new_scope): 628 yield scope 629 630 if prev_scope: 631 union_scope_stack.pop() 632 union_scope.union_scopes = [prev_scope, scope] 633 prev_scope = union_scope 634 635 yield union_scope 636 else: 637 prev_scope = scope 638 639 640def _traverse_ctes(scope): 641 sources = {} 642 643 for cte in scope.ctes: 644 cte_name = cte.alias 645 646 # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. 647 # thus the recursive scope is the first section of the union. 
648 with_ = scope.expression.args.get("with") 649 if with_ and with_.recursive: 650 union = cte.this 651 652 if isinstance(union, exp.SetOperation): 653 sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE) 654 655 child_scope = None 656 657 for child_scope in _traverse_scope( 658 scope.branch( 659 cte.this, 660 cte_sources=sources, 661 outer_columns=cte.alias_column_names, 662 scope_type=ScopeType.CTE, 663 ) 664 ): 665 yield child_scope 666 667 # append the final child_scope yielded 668 if child_scope: 669 sources[cte_name] = child_scope 670 scope.cte_scopes.append(child_scope) 671 672 scope.sources.update(sources) 673 scope.cte_sources.update(sources) 674 675 676def _is_derived_table(expression: exp.Subquery) -> bool: 677 """ 678 We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", 679 as it doesn't introduce a new scope. If an alias is present, it shadows all names 680 under the Subquery, so that's one exception to this rule. 681 """ 682 return isinstance(expression, exp.Subquery) and bool( 683 expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES) 684 ) 685 686 687def _is_from_or_join(expression: exp.Expression) -> bool: 688 """ 689 Determine if `expression` is the FROM or JOIN clause of a SELECT statement. 
690 """ 691 parent = expression.parent 692 693 # Subqueries can be arbitrarily nested 694 while isinstance(parent, exp.Subquery): 695 parent = parent.parent 696 697 return isinstance(parent, (exp.From, exp.Join)) 698 699 700def _traverse_tables(scope): 701 sources = {} 702 703 # Traverse FROMs, JOINs, and LATERALs in the order they are defined 704 expressions = [] 705 from_ = scope.expression.args.get("from") 706 if from_: 707 expressions.append(from_.this) 708 709 for join in scope.expression.args.get("joins") or []: 710 expressions.append(join.this) 711 712 if isinstance(scope.expression, exp.Table): 713 expressions.append(scope.expression) 714 715 expressions.extend(scope.expression.args.get("laterals") or []) 716 717 for expression in expressions: 718 if isinstance(expression, exp.Final): 719 expression = expression.this 720 if isinstance(expression, exp.Table): 721 table_name = expression.name 722 source_name = expression.alias_or_name 723 724 if table_name in scope.sources and not expression.db: 725 # This is a reference to a parent source (e.g. a CTE), not an actual table, unless 726 # it is pivoted, because then we get back a new table and hence a new source. 
727 pivots = expression.args.get("pivots") 728 if pivots: 729 sources[pivots[0].alias] = expression 730 else: 731 sources[source_name] = scope.sources[table_name] 732 elif source_name in sources: 733 sources[find_new_name(sources, table_name)] = expression 734 else: 735 sources[source_name] = expression 736 737 # Make sure to not include the joins twice 738 if expression is not scope.expression: 739 expressions.extend(join.this for join in expression.args.get("joins") or []) 740 741 continue 742 743 if not isinstance(expression, exp.DerivedTable): 744 continue 745 746 if isinstance(expression, exp.UDTF): 747 lateral_sources = sources 748 scope_type = ScopeType.UDTF 749 scopes = scope.udtf_scopes 750 elif _is_derived_table(expression): 751 lateral_sources = None 752 scope_type = ScopeType.DERIVED_TABLE 753 scopes = scope.derived_table_scopes 754 expressions.extend(join.this for join in expression.args.get("joins") or []) 755 else: 756 # Makes sure we check for possible sources in nested table constructs 757 expressions.append(expression.this) 758 expressions.extend(join.this for join in expression.args.get("joins") or []) 759 continue 760 761 for child_scope in _traverse_scope( 762 scope.branch( 763 expression, 764 lateral_sources=lateral_sources, 765 outer_columns=expression.alias_column_names, 766 scope_type=scope_type, 767 ) 768 ): 769 yield child_scope 770 771 # Tables without aliases will be set as "" 772 # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. 773 # Until then, this means that only a single, unaliased derived table is allowed (rather, 774 # the latest one wins. 
775 sources[expression.alias] = child_scope 776 777 # append the final child_scope yielded 778 scopes.append(child_scope) 779 scope.table_scopes.append(child_scope) 780 781 scope.sources.update(sources) 782 783 784def _traverse_subqueries(scope): 785 for subquery in scope.subqueries: 786 top = None 787 for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): 788 yield child_scope 789 top = child_scope 790 scope.subquery_scopes.append(top) 791 792 793def _traverse_udtfs(scope): 794 if isinstance(scope.expression, exp.Unnest): 795 expressions = scope.expression.expressions 796 elif isinstance(scope.expression, exp.Lateral): 797 expressions = [scope.expression.this] 798 else: 799 expressions = [] 800 801 sources = {} 802 for expression in expressions: 803 if _is_derived_table(expression): 804 top = None 805 for child_scope in _traverse_scope( 806 scope.branch( 807 expression, 808 scope_type=ScopeType.SUBQUERY, 809 outer_columns=expression.alias_column_names, 810 ) 811 ): 812 yield child_scope 813 top = child_scope 814 sources[expression.alias] = child_scope 815 816 scope.subquery_scopes.append(top) 817 818 scope.sources.update(sources) 819 820 821def walk_in_scope(expression, bfs=True, prune=None): 822 """ 823 Returns a generator object which visits all nodes in the syntrax tree, stopping at 824 nodes that start child scopes. 825 826 Args: 827 expression (exp.Expression): 828 bfs (bool): if set to True the BFS traversal order will be applied, 829 otherwise the DFS traversal will be used instead. 830 prune ((node, parent, arg_key) -> bool): callable that returns True if 831 the generator should stop traversing this branch of the tree. 832 833 Yields: 834 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key 835 """ 836 # We'll use this variable to pass state into the dfs generator. 837 # Whenever we set it to True, we exclude a subtree from traversal. 
838 crossed_scope_boundary = False 839 840 for node in expression.walk( 841 bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n)) 842 ): 843 crossed_scope_boundary = False 844 845 yield node 846 847 if node is expression: 848 continue 849 if ( 850 isinstance(node, exp.CTE) 851 or ( 852 isinstance(node.parent, (exp.From, exp.Join, exp.Subquery)) 853 and (_is_derived_table(node) or isinstance(node, exp.UDTF)) 854 ) 855 or isinstance(node, exp.UNWRAPPED_QUERIES) 856 ): 857 crossed_scope_boundary = True 858 859 if isinstance(node, (exp.Subquery, exp.UDTF)): 860 # The following args are not actually in the inner scope, so we should visit them 861 for key in ("joins", "laterals", "pivots"): 862 for arg in node.args.get(key) or []: 863 yield from walk_in_scope(arg, bfs=bfs) 864 865 866def find_all_in_scope(expression, expression_types, bfs=True): 867 """ 868 Returns a generator object which visits all nodes in this scope and only yields those that 869 match at least one of the specified expression types. 870 871 This does NOT traverse into subscopes. 872 873 Args: 874 expression (exp.Expression): 875 expression_types (tuple[type]|type): the expression type(s) to match. 876 bfs (bool): True to use breadth-first search, False to use depth-first. 877 878 Yields: 879 exp.Expression: nodes 880 """ 881 for expression in walk_in_scope(expression, bfs=bfs): 882 if isinstance(expression, tuple(ensure_collection(expression_types))): 883 yield expression 884 885 886def find_in_scope(expression, expression_types, bfs=True): 887 """ 888 Returns the first node in this scope which matches at least one of the specified types. 889 890 This does NOT traverse into subscopes. 891 892 Args: 893 expression (exp.Expression): 894 expression_types (tuple[type]|type): the expression type(s) to match. 895 bfs (bool): True to use breadth-first search, False to use depth-first. 
896 897 Returns: 898 exp.Expression: the node which matches the criteria or None if no node matching 899 the criteria was found. 900 """ 901 return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Kind of a scope relative to its parent; values are 1..6 in declaration order."""

    # Explicit values equal to what auto() assigns (starting at 1).
    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6
An enumeration of the kinds of scope relative to its parent: ROOT, SUBQUERY, DERIVED_TABLE, CTE, UNION, and UDTF.
28class Scope: 29 """ 30 Selection scope. 31 32 Attributes: 33 expression (exp.Select|exp.SetOperation): Root expression of this scope 34 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 35 a Table expression or another Scope instance. For example: 36 SELECT * FROM x {"x": Table(this="x")} 37 SELECT * FROM x AS y {"y": Table(this="x")} 38 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 39 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 40 For example: 41 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 42 The LATERAL VIEW EXPLODE gets x as a source. 43 cte_sources (dict[str, Scope]): Sources from CTES 44 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 45 defines a column list for the alias of this scope, this is that list of columns. 46 For example: 47 SELECT * FROM (SELECT ...) AS y(col1, col2) 48 The inner query would have `["col1", "col2"]` for its `outer_columns` 49 parent (Scope): Parent scope 50 scope_type (ScopeType): Type of this scope, relative to it's parent 51 subquery_scopes (list[Scope]): List of all child scopes for subqueries 52 cte_scopes (list[Scope]): List of all child scopes for CTEs 53 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 54 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 55 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 56 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 57 a list of the left and right child scopes. 
58 """ 59 60 def __init__( 61 self, 62 expression, 63 sources=None, 64 outer_columns=None, 65 parent=None, 66 scope_type=ScopeType.ROOT, 67 lateral_sources=None, 68 cte_sources=None, 69 can_be_correlated=None, 70 ): 71 self.expression = expression 72 self.sources = sources or {} 73 self.lateral_sources = lateral_sources or {} 74 self.cte_sources = cte_sources or {} 75 self.sources.update(self.lateral_sources) 76 self.sources.update(self.cte_sources) 77 self.outer_columns = outer_columns or [] 78 self.parent = parent 79 self.scope_type = scope_type 80 self.subquery_scopes = [] 81 self.derived_table_scopes = [] 82 self.table_scopes = [] 83 self.cte_scopes = [] 84 self.union_scopes = [] 85 self.udtf_scopes = [] 86 self.can_be_correlated = can_be_correlated 87 self.clear_cache() 88 89 def clear_cache(self): 90 self._collected = False 91 self._raw_columns = None 92 self._stars = None 93 self._derived_tables = None 94 self._udtfs = None 95 self._tables = None 96 self._ctes = None 97 self._subqueries = None 98 self._selected_sources = None 99 self._columns = None 100 self._external_columns = None 101 self._join_hints = None 102 self._pivots = None 103 self._references = None 104 self._semi_anti_join_tables = None 105 106 def branch( 107 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 108 ): 109 """Branch from the current scope to a new, inner scope""" 110 return Scope( 111 expression=expression.unnest(), 112 sources=sources.copy() if sources else None, 113 parent=self, 114 scope_type=scope_type, 115 cte_sources={**self.cte_sources, **(cte_sources or {})}, 116 lateral_sources=lateral_sources.copy() if lateral_sources else None, 117 can_be_correlated=self.can_be_correlated 118 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 119 **kwargs, 120 ) 121 122 def _collect(self): 123 self._tables = [] 124 self._ctes = [] 125 self._subqueries = [] 126 self._derived_tables = [] 127 self._udtfs = [] 128 self._raw_columns = [] 129 
self._stars = [] 130 self._join_hints = [] 131 self._semi_anti_join_tables = set() 132 133 for node in self.walk(bfs=False): 134 if node is self.expression: 135 continue 136 137 if isinstance(node, exp.Dot) and node.is_star: 138 self._stars.append(node) 139 elif isinstance(node, exp.Column): 140 if isinstance(node.this, exp.Star): 141 self._stars.append(node) 142 else: 143 self._raw_columns.append(node) 144 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 145 parent = node.parent 146 if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join: 147 self._semi_anti_join_tables.add(node.alias_or_name) 148 149 self._tables.append(node) 150 elif isinstance(node, exp.JoinHint): 151 self._join_hints.append(node) 152 elif isinstance(node, exp.UDTF): 153 self._udtfs.append(node) 154 elif isinstance(node, exp.CTE): 155 self._ctes.append(node) 156 elif _is_derived_table(node) and _is_from_or_join(node): 157 self._derived_tables.append(node) 158 elif isinstance(node, exp.UNWRAPPED_QUERIES): 159 self._subqueries.append(node) 160 161 self._collected = True 162 163 def _ensure_collected(self): 164 if not self._collected: 165 self._collect() 166 167 def walk(self, bfs=True, prune=None): 168 return walk_in_scope(self.expression, bfs=bfs, prune=None) 169 170 def find(self, *expression_types, bfs=True): 171 return find_in_scope(self.expression, expression_types, bfs=bfs) 172 173 def find_all(self, *expression_types, bfs=True): 174 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 175 176 def replace(self, old, new): 177 """ 178 Replace `old` with `new`. 179 180 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 181 182 Args: 183 old (exp.Expression): old node 184 new (exp.Expression): new node 185 """ 186 old.replace(new) 187 self.clear_cache() 188 189 @property 190 def tables(self): 191 """ 192 List of tables in this scope. 
193 194 Returns: 195 list[exp.Table]: tables 196 """ 197 self._ensure_collected() 198 return self._tables 199 200 @property 201 def ctes(self): 202 """ 203 List of CTEs in this scope. 204 205 Returns: 206 list[exp.CTE]: ctes 207 """ 208 self._ensure_collected() 209 return self._ctes 210 211 @property 212 def derived_tables(self): 213 """ 214 List of derived tables in this scope. 215 216 For example: 217 SELECT * FROM (SELECT ...) <- that's a derived table 218 219 Returns: 220 list[exp.Subquery]: derived tables 221 """ 222 self._ensure_collected() 223 return self._derived_tables 224 225 @property 226 def udtfs(self): 227 """ 228 List of "User Defined Tabular Functions" in this scope. 229 230 Returns: 231 list[exp.UDTF]: UDTFs 232 """ 233 self._ensure_collected() 234 return self._udtfs 235 236 @property 237 def subqueries(self): 238 """ 239 List of subqueries in this scope. 240 241 For example: 242 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 243 244 Returns: 245 list[exp.Select | exp.SetOperation]: subqueries 246 """ 247 self._ensure_collected() 248 return self._subqueries 249 250 @property 251 def stars(self) -> t.List[exp.Column | exp.Dot]: 252 """ 253 List of star expressions (columns or dots) in this scope. 254 """ 255 self._ensure_collected() 256 return self._stars 257 258 @property 259 def columns(self): 260 """ 261 List of columns in this scope. 262 263 Returns: 264 list[exp.Column]: Column instances in this scope, plus any 265 Columns that reference this scope from correlated subqueries. 
266 """ 267 if self._columns is None: 268 self._ensure_collected() 269 columns = self._raw_columns 270 271 external_columns = [ 272 column 273 for scope in itertools.chain( 274 self.subquery_scopes, 275 self.udtf_scopes, 276 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 277 ) 278 for column in scope.external_columns 279 ] 280 281 named_selects = set(self.expression.named_selects) 282 283 self._columns = [] 284 for column in columns + external_columns: 285 ancestor = column.find_ancestor( 286 exp.Select, 287 exp.Qualify, 288 exp.Order, 289 exp.Having, 290 exp.Hint, 291 exp.Table, 292 exp.Star, 293 exp.Distinct, 294 ) 295 if ( 296 not ancestor 297 or column.table 298 or isinstance(ancestor, exp.Select) 299 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 300 or ( 301 isinstance(ancestor, (exp.Order, exp.Distinct)) 302 and ( 303 isinstance(ancestor.parent, (exp.Window, exp.WithinGroup)) 304 or column.name not in named_selects 305 ) 306 ) 307 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") 308 ): 309 self._columns.append(column) 310 311 return self._columns 312 313 @property 314 def selected_sources(self): 315 """ 316 Mapping of nodes and sources that are actually selected from in this scope. 317 318 That is, all tables in a schema are selectable at any point. But a 319 table only becomes a selected source if it's included in a FROM or JOIN clause. 
320 321 Returns: 322 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 323 """ 324 if self._selected_sources is None: 325 result = {} 326 327 for name, node in self.references: 328 if name in self._semi_anti_join_tables: 329 # The RHS table of SEMI/ANTI joins shouldn't be collected as a 330 # selected source 331 continue 332 333 if name in result: 334 raise OptimizeError(f"Alias already used: {name}") 335 if name in self.sources: 336 result[name] = (node, self.sources[name]) 337 338 self._selected_sources = result 339 return self._selected_sources 340 341 @property 342 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 343 if self._references is None: 344 self._references = [] 345 346 for table in self.tables: 347 self._references.append((table.alias_or_name, table)) 348 for expression in itertools.chain(self.derived_tables, self.udtfs): 349 self._references.append( 350 ( 351 expression.alias, 352 expression if expression.args.get("pivots") else expression.unnest(), 353 ) 354 ) 355 356 return self._references 357 358 @property 359 def external_columns(self): 360 """ 361 Columns that appear to reference sources in outer scopes. 362 363 Returns: 364 list[exp.Column]: Column instances that don't reference 365 sources in the current scope. 366 """ 367 if self._external_columns is None: 368 if isinstance(self.expression, exp.SetOperation): 369 left, right = self.union_scopes 370 self._external_columns = left.external_columns + right.external_columns 371 else: 372 self._external_columns = [ 373 c 374 for c in self.columns 375 if c.table not in self.selected_sources 376 and c.table not in self.semi_or_anti_join_tables 377 ] 378 379 return self._external_columns 380 381 @property 382 def unqualified_columns(self): 383 """ 384 Unqualified columns in the current scope. 
385 386 Returns: 387 list[exp.Column]: Unqualified columns 388 """ 389 return [c for c in self.columns if not c.table] 390 391 @property 392 def join_hints(self): 393 """ 394 Hints that exist in the scope that reference tables 395 396 Returns: 397 list[exp.JoinHint]: Join hints that are referenced within the scope 398 """ 399 if self._join_hints is None: 400 return [] 401 return self._join_hints 402 403 @property 404 def pivots(self): 405 if not self._pivots: 406 self._pivots = [ 407 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 408 ] 409 410 return self._pivots 411 412 @property 413 def semi_or_anti_join_tables(self): 414 return self._semi_anti_join_tables or set() 415 416 def source_columns(self, source_name): 417 """ 418 Get all columns in the current scope for a particular source. 419 420 Args: 421 source_name (str): Name of the source 422 Returns: 423 list[exp.Column]: Column instances that reference `source_name` 424 """ 425 return [column for column in self.columns if column.table == source_name] 426 427 @property 428 def is_subquery(self): 429 """Determine if this scope is a subquery""" 430 return self.scope_type == ScopeType.SUBQUERY 431 432 @property 433 def is_derived_table(self): 434 """Determine if this scope is a derived table""" 435 return self.scope_type == ScopeType.DERIVED_TABLE 436 437 @property 438 def is_union(self): 439 """Determine if this scope is a union""" 440 return self.scope_type == ScopeType.UNION 441 442 @property 443 def is_cte(self): 444 """Determine if this scope is a common table expression""" 445 return self.scope_type == ScopeType.CTE 446 447 @property 448 def is_root(self): 449 """Determine if this is the root scope""" 450 return self.scope_type == ScopeType.ROOT 451 452 @property 453 def is_udtf(self): 454 """Determine if this scope is a UDTF (User Defined Table Function)""" 455 return self.scope_type == ScopeType.UDTF 456 457 @property 458 def is_correlated_subquery(self): 459 """Determine if 
this scope is a correlated subquery""" 460 return bool(self.can_be_correlated and self.external_columns) 461 462 def rename_source(self, old_name, new_name): 463 """Rename a source in this scope""" 464 columns = self.sources.pop(old_name or "", []) 465 self.sources[new_name] = columns 466 467 def add_source(self, name, source): 468 """Add a source to this scope""" 469 self.sources[name] = source 470 self.clear_cache() 471 472 def remove_source(self, name): 473 """Remove a source from this scope""" 474 self.sources.pop(name, None) 475 self.clear_cache() 476 477 def __repr__(self): 478 return f"Scope<{self.expression.sql()}>" 479 480 def traverse(self): 481 """ 482 Traverse the scope tree from this node. 483 484 Yields: 485 Scope: scope instances in depth-first-search post-order 486 """ 487 stack = [self] 488 result = [] 489 while stack: 490 scope = stack.pop() 491 result.append(scope) 492 stack.extend( 493 itertools.chain( 494 scope.cte_scopes, 495 scope.union_scopes, 496 scope.table_scopes, 497 scope.subquery_scopes, 498 ) 499 ) 500 501 yield from reversed(result) 502 503 def ref_count(self): 504 """ 505 Count the number of times each scope in this tree is referenced. 506 507 Returns: 508 dict[int, int]: Mapping of Scope instance ID to reference count 509 """ 510 scope_ref_count = defaultdict(lambda: 0) 511 512 for scope in self.traverse(): 513 for _, source in scope.selected_sources.values(): 514 scope_ref_count[id(source)] += 1 515 516 return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.SetOperation): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: `SELECT * FROM x` → `{"x": Table(this="x")}`; `SELECT * FROM x AS y` → `{"y": Table(this="x")}`; `SELECT * FROM (SELECT ...) AS y` → `{"y": Scope(...)}`.
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;` the LATERAL VIEW EXPLODE gets `x` as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_columns`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_columns=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
    can_be_correlated=None,
):
    """
    Initialise a scope rooted at `expression`.

    `self.sources` ends up holding `sources` plus every lateral source and then
    every CTE source, so on a name clash the CTE entry wins. A non-empty `sources`
    mapping is updated in place rather than copied.
    """
    self.expression = expression

    lateral = lateral_sources or {}
    ctes = cte_sources or {}
    merged = sources or {}
    merged.update(lateral)
    merged.update(ctes)

    self.sources = merged
    self.lateral_sources = lateral
    self.cte_sources = ctes
    self.outer_columns = outer_columns or []
    self.parent = parent
    self.scope_type = scope_type

    # Every child-scope list starts out empty (six distinct list objects).
    (
        self.subquery_scopes,
        self.derived_table_scopes,
        self.table_scopes,
        self.cte_scopes,
        self.union_scopes,
        self.udtf_scopes,
    ) = ([], [], [], [], [], [])

    self.can_be_correlated = can_be_correlated
    self.clear_cache()
def clear_cache(self):
    """
    Reset every lazily computed attribute so it is rebuilt on next access.

    `_collected` goes back to False and each cached result is set to None.
    """
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_stars",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
        "_semi_anti_join_tables",
    ):
        setattr(self, cached_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Branch from the current scope to a new, inner scope.

    The child gets shallow copies of `sources` and `lateral_sources`. It inherits this
    scope's CTE sources, with `cte_sources` overriding them. It may be correlated if
    this scope may be, or if it is a SUBQUERY or UDTF scope. Extra keyword arguments
    are forwarded to `Scope`.
    """
    inner_expression = expression.unnest()
    own_sources = sources.copy() if sources else None
    visible_ctes = {**self.cte_sources, **(cte_sources or {})}
    own_laterals = lateral_sources.copy() if lateral_sources else None
    correlatable = self.can_be_correlated or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF)

    return Scope(
        expression=inner_expression,
        sources=own_sources,
        parent=self,
        scope_type=scope_type,
        cte_sources=visible_ctes,
        lateral_sources=own_laterals,
        can_be_correlated=correlatable,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` in the tree and invalidate this scope's cached data.

    Use this instead of `exp.Expression.replace` so the `Scope` is kept up-to-date.

    Args:
        old (exp.Expression): node currently in the tree
        new (exp.Expression): node to put in its place
    """
    old.replace(new)
    # The tree changed, so anything collected from it may now be stale.
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Table nodes collected from this scope (triggers collection on first use).

    Returns:
        list[exp.Table]: tables
    """
    self._ensure_collected()
    found = self._tables
    return found
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    CTE nodes collected from this scope (triggers collection on first use).

    Returns:
        list[exp.CTE]: ctes
    """
    self._ensure_collected()
    found = self._ctes
    return found
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables (subqueries used as FROM/JOIN sources) collected from this scope.

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    Returns:
        list[exp.Subquery]: derived tables
    """
    self._ensure_collected()
    found = self._derived_tables
    return found
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Functions" collected from this scope.

    Returns:
        list[exp.UDTF]: UDTFs
    """
    self._ensure_collected()
    found = self._udtfs
    return found
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries collected from this scope.

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    Returns:
        list[exp.Select | exp.SetOperation]: subqueries
    """
    self._ensure_collected()
    found = self._subqueries
    return found
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.SetOperation]: subqueries
@property
def stars(self) -> t.List[exp.Column | exp.Dot]:
    """
    Star expressions (`*` columns or `x.*` dots) collected from this scope.
    """
    self._ensure_collected()
    found = self._stars
    return found
List of star expressions (columns or dots) in this scope.
@property
def columns(self):
    """
    List of columns in this scope.

    The result is built once and cached in `self._columns` (reset by `clear_cache`).

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is None:
        self._ensure_collected()
        columns = self._raw_columns

        # External columns of child scopes that may refer back into this scope:
        # subqueries, UDTFs, and derived tables flagged `can_be_correlated`.
        external_columns = [
            column
            for scope in itertools.chain(
                self.subquery_scopes,
                self.udtf_scopes,
                (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
            )
            for column in scope.external_columns
        ]

        # Output names of this scope's projections, checked against ORDER BY /
        # DISTINCT references below.
        named_selects = set(self.expression.named_selects)

        self._columns = []
        for column in columns + external_columns:
            # The nearest of these ancestor kinds decides whether the column is kept.
            ancestor = column.find_ancestor(
                exp.Select,
                exp.Qualify,
                exp.Order,
                exp.Having,
                exp.Hint,
                exp.Table,
                exp.Star,
                exp.Distinct,
            )
            # Not kept: unqualified columns whose nearest ancestor is QUALIFY, HAVING or
            # a hint. Also not kept: unqualified ORDER BY/DISTINCT columns (outside a
            # window/WITHIN GROUP) whose name matches a projection. Presumably these
            # may refer to select aliases rather than source columns -- TODO confirm.
            if (
                not ancestor
                or column.table
                or isinstance(ancestor, exp.Select)
                or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                or (
                    isinstance(ancestor, (exp.Order, exp.Distinct))
                    and (
                        isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
                        or column.name not in named_selects
                    )
                )
                or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
            ):
                self._columns.append(column)

    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Mapping of nodes and sources that are actually selected from in this scope.

    Every table in a schema is selectable at any point, but a table only becomes a
    selected source once it appears in a FROM or JOIN clause. Names coming from the
    RHS of SEMI/ANTI joins are skipped; a repeated name raises `OptimizeError`.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
    """
    if self._selected_sources is not None:
        return self._selected_sources

    picked = {}
    for name, node in self.references:
        # The RHS table of SEMI/ANTI joins is not collected as a selected source.
        if name in self._semi_anti_join_tables:
            continue
        if name in picked:
            raise OptimizeError(f"Alias already used: {name}")
        if name in self.sources:
            picked[name] = (node, self.sources[name])

    self._selected_sources = picked
    return self._selected_sources
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    (name, node) pairs for this scope's tables, then derived tables, then UDTFs.

    Tables are keyed by alias-or-name, derived tables and UDTFs by alias. A derived
    table/UDTF that carries pivots is kept as-is; otherwise it is unnested.
    """
    if self._references is None:
        self._references = refs = []
        refs.extend((table.alias_or_name, table) for table in self.tables)
        refs.extend(
            (source.alias, source if source.args.get("pivots") else source.unnest())
            for source in itertools.chain(self.derived_tables, self.udtfs)
        )

    return self._references
@property
def external_columns(self):
    """
    Columns that appear to reference sources in outer scopes.

    For a set operation the two branch scopes' results are concatenated; otherwise
    these are the columns whose table is neither a selected source nor a SEMI/ANTI
    join table of this scope.

    Returns:
        list[exp.Column]: Column instances that don't reference
            sources in the current scope.
    """
    if self._external_columns is not None:
        return self._external_columns

    if isinstance(self.expression, exp.SetOperation):
        left, right = self.union_scopes
        found = left.external_columns + right.external_columns
    else:
        found = []
        for column in self.columns:
            if column.table in self.selected_sources:
                continue
            if column.table in self.semi_or_anti_join_tables:
                continue
            found.append(column)

    self._external_columns = found
    return self._external_columns
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns of this scope that carry no table qualifier.

    Returns:
        list[exp.Column]: Unqualified columns
    """
    return list(filter(lambda column: not column.table, self.columns))
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Hints that exist in the scope that reference tables.

    Collection is triggered first, like the other collected properties, so the
    result no longer depends on whether another property was read earlier.

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    self._ensure_collected()
    if self._join_hints is None:
        return []
    return self._join_hints
Join hints in this scope that reference tables.
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Columns of this scope whose table qualifier equals `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances that reference `source_name`
    """
    matches = []
    for column in self.columns:
        if column.table == source_name:
            matches.append(column)
    return matches
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self):
    """True when this scope's type is ScopeType.SUBQUERY."""
    kind = self.scope_type
    return kind == ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """True when this scope's type is ScopeType.DERIVED_TABLE."""
    kind = self.scope_type
    return kind == ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """True when this scope's type is ScopeType.UNION."""
    kind = self.scope_type
    return kind == ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """True when this scope's type is ScopeType.CTE (a common table expression)."""
    kind = self.scope_type
    return kind == ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """True when this scope's type is ScopeType.ROOT."""
    kind = self.scope_type
    return kind == ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """True when this scope's type is ScopeType.UDTF (User Defined Table Function)."""
    kind = self.scope_type
    return kind == ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Move the source registered under `old_name` to `new_name`.

    A falsy `old_name` is looked up as "". If nothing is registered under that key,
    `new_name` is mapped to an empty list.

    NOTE(review): unlike `add_source`/`remove_source`, this does not call
    `clear_cache()`. Cached properties such as `selected_sources` keep the old
    name -- confirm this is intended.
    """
    key = old_name if old_name else ""
    self.sources[new_name] = self.sources.pop(key, [])
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name` (overwriting any entry) and reset cached data."""
    registry = self.sources
    registry[name] = source
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Drop the source called `name` if it is registered, then reset cached data."""
    if name in self.sources:
        del self.sources[name]
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Traverse the scope tree from this node.

    The tree is walked with an explicit stack. Children are pushed in the order CTE,
    union, table, subquery scopes. The visit order is then replayed in reverse, so
    children come out before their parents.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    pending = [self]
    visit_order = []
    while pending:
        current = pending.pop()
        visit_order.append(current)
        for children in (
            current.cte_scopes,
            current.union_scopes,
            current.table_scopes,
            current.subquery_scopes,
        ):
            pending.extend(children)

    for scope in reversed(visit_order):
        yield scope
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count the number of times each scope in this tree is referenced.

    For every scope yielded by `traverse`, each selected source adds one to the count
    keyed by `id(source)`. Unseen ids read as 0.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    scope_ref_count = defaultdict(int)

    for scope in self.traverse():
        for _, source in scope.selected_sources.values():
            scope_ref_count[id(source)] += 1

    return scope_ref_count
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances (empty if `expression` is not one of
        the TRAVERSABLES types)
    """
    if not isinstance(expression, TRAVERSABLES):
        return []
    return list(_traverse_scope(Scope(expression)))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    The root is the last scope produced by `traverse_scope`. When no scopes are
    produced, `seq_get` presumably returns None (hence the Optional return type).

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): the root expression to walk.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if the generator should
            stop traversing this branch of the tree. It is also forwarded to the
            separate walks over the joins, laterals and pivots of subqueries/UDTFs.

    Yields:
        exp.Expression: nodes belonging to this scope
    """
    # We'll use this variable to pass state into the walk generator: it calls the
    # prune callback for a node only after we resume from `yield node`, so setting
    # the flag below excludes that node's subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should
                # visit them. Honour the caller's prune there too.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the root expression to walk.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:
exp.Expression: nodes belonging to this scope
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the root expression to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalise the requested type(s) once rather than on every visited node.
    wanted = tuple(ensure_collection(expression_types))
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the root expression to search.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Return the first node in this scope that matches at least one of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the root expression to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the root expression to search.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.