sqlglot.optimizer.scope
from __future__ import annotations

import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import ensure_collection, find_new_name, seq_get

logger = logging.getLogger("sqlglot")

TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)


class ScopeType(Enum):
    """Kind of a `Scope`, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.SetOperation): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
            The lateral and CTE sources are merged into this mapping as well.
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources (dict[str, Scope]): Sources from CTEs
        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        can_be_correlated (bool|None): Whether columns in this scope may reference
            sources of an enclosing scope. `branch` sets it for SUBQUERY and UDTF
            scopes and propagates it to all of their descendants.
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_columns=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
        cte_sources=None,
        can_be_correlated=None,
    ):
        self.expression = expression
        # Copy the caller's mapping: the lateral/CTE merges below would otherwise
        # mutate a dict the caller still owns (`branch` already passes a copy).
        self.sources = sources.copy() if sources else {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # Lateral and CTE sources are also visible as regular sources. CTE entries
        # are merged last, so they take precedence on a name clash.
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        # Child scope lists; these start empty and are populated externally
        # (e.g. by the scope traversal), not by this class's constructor.
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.can_be_correlated = can_be_correlated
        self.clear_cache()

    def clear_cache(self):
        """
        Reset every lazily computed attribute of this scope.

        Afterwards the next property access re-collects nodes from `self.expression`;
        `replace` calls this after mutating the tree so the scope stays up to date.
        """
        self._collected = False
        self._raw_columns = None
        self._table_columns = None
        self._stars = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None
        self._semi_anti_join_tables = None
105 106 def branch( 107 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 108 ): 109 """Branch from the current scope to a new, inner scope""" 110 return Scope( 111 expression=expression.unnest(), 112 sources=sources.copy() if sources else None, 113 parent=self, 114 scope_type=scope_type, 115 cte_sources={**self.cte_sources, **(cte_sources or {})}, 116 lateral_sources=lateral_sources.copy() if lateral_sources else None, 117 can_be_correlated=self.can_be_correlated 118 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 119 **kwargs, 120 ) 121 122 def _collect(self): 123 self._tables = [] 124 self._ctes = [] 125 self._subqueries = [] 126 self._derived_tables = [] 127 self._udtfs = [] 128 self._raw_columns = [] 129 self._table_columns = [] 130 self._stars = [] 131 self._join_hints = [] 132 self._semi_anti_join_tables = set() 133 134 for node in self.walk(bfs=False): 135 if node is self.expression: 136 continue 137 138 if isinstance(node, exp.Dot) and node.is_star: 139 self._stars.append(node) 140 elif isinstance(node, exp.Column): 141 if isinstance(node.this, exp.Star): 142 self._stars.append(node) 143 else: 144 self._raw_columns.append(node) 145 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 146 parent = node.parent 147 if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join: 148 self._semi_anti_join_tables.add(node.alias_or_name) 149 150 self._tables.append(node) 151 elif isinstance(node, exp.JoinHint): 152 self._join_hints.append(node) 153 elif isinstance(node, exp.UDTF): 154 self._udtfs.append(node) 155 elif isinstance(node, exp.CTE): 156 self._ctes.append(node) 157 elif _is_derived_table(node) and _is_from_or_join(node): 158 self._derived_tables.append(node) 159 elif isinstance(node, exp.UNWRAPPED_QUERIES): 160 self._subqueries.append(node) 161 elif isinstance(node, exp.TableColumn): 162 self._table_columns.append(node) 163 164 self._collected = True 165 166 def 
_ensure_collected(self): 167 if not self._collected: 168 self._collect() 169 170 def walk(self, bfs=True, prune=None): 171 return walk_in_scope(self.expression, bfs=bfs, prune=None) 172 173 def find(self, *expression_types, bfs=True): 174 return find_in_scope(self.expression, expression_types, bfs=bfs) 175 176 def find_all(self, *expression_types, bfs=True): 177 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 178 179 def replace(self, old, new): 180 """ 181 Replace `old` with `new`. 182 183 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 184 185 Args: 186 old (exp.Expression): old node 187 new (exp.Expression): new node 188 """ 189 old.replace(new) 190 self.clear_cache() 191 192 @property 193 def tables(self): 194 """ 195 List of tables in this scope. 196 197 Returns: 198 list[exp.Table]: tables 199 """ 200 self._ensure_collected() 201 return self._tables 202 203 @property 204 def ctes(self): 205 """ 206 List of CTEs in this scope. 207 208 Returns: 209 list[exp.CTE]: ctes 210 """ 211 self._ensure_collected() 212 return self._ctes 213 214 @property 215 def derived_tables(self): 216 """ 217 List of derived tables in this scope. 218 219 For example: 220 SELECT * FROM (SELECT ...) <- that's a derived table 221 222 Returns: 223 list[exp.Subquery]: derived tables 224 """ 225 self._ensure_collected() 226 return self._derived_tables 227 228 @property 229 def udtfs(self): 230 """ 231 List of "User Defined Tabular Functions" in this scope. 232 233 Returns: 234 list[exp.UDTF]: UDTFs 235 """ 236 self._ensure_collected() 237 return self._udtfs 238 239 @property 240 def subqueries(self): 241 """ 242 List of subqueries in this scope. 243 244 For example: 245 SELECT * FROM x WHERE a IN (SELECT ...) 
<- that's a subquery 246 247 Returns: 248 list[exp.Select | exp.SetOperation]: subqueries 249 """ 250 self._ensure_collected() 251 return self._subqueries 252 253 @property 254 def stars(self) -> t.List[exp.Column | exp.Dot]: 255 """ 256 List of star expressions (columns or dots) in this scope. 257 """ 258 self._ensure_collected() 259 return self._stars 260 261 @property 262 def columns(self): 263 """ 264 List of columns in this scope. 265 266 Returns: 267 list[exp.Column]: Column instances in this scope, plus any 268 Columns that reference this scope from correlated subqueries. 269 """ 270 if self._columns is None: 271 self._ensure_collected() 272 columns = self._raw_columns 273 274 external_columns = [ 275 column 276 for scope in itertools.chain( 277 self.subquery_scopes, 278 self.udtf_scopes, 279 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 280 ) 281 for column in scope.external_columns 282 ] 283 284 named_selects = set(self.expression.named_selects) 285 286 self._columns = [] 287 for column in columns + external_columns: 288 ancestor = column.find_ancestor( 289 exp.Select, 290 exp.Qualify, 291 exp.Order, 292 exp.Having, 293 exp.Hint, 294 exp.Table, 295 exp.Star, 296 exp.Distinct, 297 ) 298 if ( 299 not ancestor 300 or column.table 301 or isinstance(ancestor, exp.Select) 302 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 303 or ( 304 isinstance(ancestor, (exp.Order, exp.Distinct)) 305 and ( 306 isinstance(ancestor.parent, (exp.Window, exp.WithinGroup)) 307 or column.name not in named_selects 308 ) 309 ) 310 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") 311 ): 312 self._columns.append(column) 313 314 return self._columns 315 316 @property 317 def table_columns(self): 318 if self._table_columns is None: 319 self._ensure_collected() 320 321 return self._table_columns 322 323 @property 324 def selected_sources(self): 325 """ 326 Mapping of nodes and sources that are actually selected 
from in this scope. 327 328 That is, all tables in a schema are selectable at any point. But a 329 table only becomes a selected source if it's included in a FROM or JOIN clause. 330 331 Returns: 332 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 333 """ 334 if self._selected_sources is None: 335 result = {} 336 337 for name, node in self.references: 338 if name in self._semi_anti_join_tables: 339 # The RHS table of SEMI/ANTI joins shouldn't be collected as a 340 # selected source 341 continue 342 343 if name in result: 344 raise OptimizeError(f"Alias already used: {name}") 345 if name in self.sources: 346 result[name] = (node, self.sources[name]) 347 348 self._selected_sources = result 349 return self._selected_sources 350 351 @property 352 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 353 if self._references is None: 354 self._references = [] 355 356 for table in self.tables: 357 self._references.append((table.alias_or_name, table)) 358 for expression in itertools.chain(self.derived_tables, self.udtfs): 359 self._references.append( 360 ( 361 expression.alias, 362 expression if expression.args.get("pivots") else expression.unnest(), 363 ) 364 ) 365 366 return self._references 367 368 @property 369 def external_columns(self): 370 """ 371 Columns that appear to reference sources in outer scopes. 372 373 Returns: 374 list[exp.Column]: Column instances that don't reference 375 sources in the current scope. 
376 """ 377 if self._external_columns is None: 378 if isinstance(self.expression, exp.SetOperation): 379 left, right = self.union_scopes 380 self._external_columns = left.external_columns + right.external_columns 381 else: 382 self._external_columns = [ 383 c 384 for c in self.columns 385 if c.table not in self.selected_sources 386 and c.table not in self.semi_or_anti_join_tables 387 ] 388 389 return self._external_columns 390 391 @property 392 def unqualified_columns(self): 393 """ 394 Unqualified columns in the current scope. 395 396 Returns: 397 list[exp.Column]: Unqualified columns 398 """ 399 return [c for c in self.columns if not c.table] 400 401 @property 402 def join_hints(self): 403 """ 404 Hints that exist in the scope that reference tables 405 406 Returns: 407 list[exp.JoinHint]: Join hints that are referenced within the scope 408 """ 409 if self._join_hints is None: 410 return [] 411 return self._join_hints 412 413 @property 414 def pivots(self): 415 if not self._pivots: 416 self._pivots = [ 417 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 418 ] 419 420 return self._pivots 421 422 @property 423 def semi_or_anti_join_tables(self): 424 return self._semi_anti_join_tables or set() 425 426 def source_columns(self, source_name): 427 """ 428 Get all columns in the current scope for a particular source. 
429 430 Args: 431 source_name (str): Name of the source 432 Returns: 433 list[exp.Column]: Column instances that reference `source_name` 434 """ 435 return [column for column in self.columns if column.table == source_name] 436 437 @property 438 def is_subquery(self): 439 """Determine if this scope is a subquery""" 440 return self.scope_type == ScopeType.SUBQUERY 441 442 @property 443 def is_derived_table(self): 444 """Determine if this scope is a derived table""" 445 return self.scope_type == ScopeType.DERIVED_TABLE 446 447 @property 448 def is_union(self): 449 """Determine if this scope is a union""" 450 return self.scope_type == ScopeType.UNION 451 452 @property 453 def is_cte(self): 454 """Determine if this scope is a common table expression""" 455 return self.scope_type == ScopeType.CTE 456 457 @property 458 def is_root(self): 459 """Determine if this is the root scope""" 460 return self.scope_type == ScopeType.ROOT 461 462 @property 463 def is_udtf(self): 464 """Determine if this scope is a UDTF (User Defined Table Function)""" 465 return self.scope_type == ScopeType.UDTF 466 467 @property 468 def is_correlated_subquery(self): 469 """Determine if this scope is a correlated subquery""" 470 return bool(self.can_be_correlated and self.external_columns) 471 472 def rename_source(self, old_name, new_name): 473 """Rename a source in this scope""" 474 columns = self.sources.pop(old_name or "", []) 475 self.sources[new_name] = columns 476 477 def add_source(self, name, source): 478 """Add a source to this scope""" 479 self.sources[name] = source 480 self.clear_cache() 481 482 def remove_source(self, name): 483 """Remove a source from this scope""" 484 self.sources.pop(name, None) 485 self.clear_cache() 486 487 def __repr__(self): 488 return f"Scope<{self.expression.sql()}>" 489 490 def traverse(self): 491 """ 492 Traverse the scope tree from this node. 
493 494 Yields: 495 Scope: scope instances in depth-first-search post-order 496 """ 497 stack = [self] 498 result = [] 499 while stack: 500 scope = stack.pop() 501 result.append(scope) 502 stack.extend( 503 itertools.chain( 504 scope.cte_scopes, 505 scope.union_scopes, 506 scope.table_scopes, 507 scope.subquery_scopes, 508 ) 509 ) 510 511 yield from reversed(result) 512 513 def ref_count(self): 514 """ 515 Count the number of times each scope in this tree is referenced. 516 517 Returns: 518 dict[int, int]: Mapping of Scope instance ID to reference count 519 """ 520 scope_ref_count = defaultdict(lambda: 0) 521 522 for scope in self.traverse(): 523 for _, source in scope.selected_sources.values(): 524 scope_ref_count[id(source)] += 1 525 526 return scope_ref_count 527 528 529def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 530 """ 531 Traverse an expression by its "scopes". 532 533 "Scope" represents the current context of a Select statement. 534 535 This is helpful for optimizing queries, where we need more information than 536 the expression tree itself. For example, we might care about the source 537 names within a subquery. Returns a list because a generator could result in 538 incomplete properties which is confusing. 539 540 Examples: 541 >>> import sqlglot 542 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 543 >>> scopes = traverse_scope(expression) 544 >>> scopes[0].expression.sql(), list(scopes[0].sources) 545 ('SELECT a FROM x', ['x']) 546 >>> scopes[1].expression.sql(), list(scopes[1].sources) 547 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 548 549 Args: 550 expression: Expression to traverse 551 552 Returns: 553 A list of the created scope instances 554 """ 555 if isinstance(expression, TRAVERSABLES): 556 return list(_traverse_scope(Scope(expression))) 557 return [] 558 559 560def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 561 """ 562 Build a scope tree. 
563 564 Args: 565 expression: Expression to build the scope tree for. 566 567 Returns: 568 The root scope 569 """ 570 return seq_get(traverse_scope(expression), -1) 571 572 573def _traverse_scope(scope): 574 expression = scope.expression 575 576 if isinstance(expression, exp.Select): 577 yield from _traverse_select(scope) 578 elif isinstance(expression, exp.SetOperation): 579 yield from _traverse_ctes(scope) 580 yield from _traverse_union(scope) 581 return 582 elif isinstance(expression, exp.Subquery): 583 if scope.is_root: 584 yield from _traverse_select(scope) 585 else: 586 yield from _traverse_subqueries(scope) 587 elif isinstance(expression, exp.Table): 588 yield from _traverse_tables(scope) 589 elif isinstance(expression, exp.UDTF): 590 yield from _traverse_udtfs(scope) 591 elif isinstance(expression, exp.DDL): 592 if isinstance(expression.expression, exp.Query): 593 yield from _traverse_ctes(scope) 594 yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources)) 595 return 596 elif isinstance(expression, exp.DML): 597 yield from _traverse_ctes(scope) 598 for query in find_all_in_scope(expression, exp.Query): 599 # This check ensures we don't yield the CTE/nested queries twice 600 if not isinstance(query.parent, (exp.CTE, exp.Subquery)): 601 yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources)) 602 return 603 else: 604 logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression)) 605 return 606 607 yield scope 608 609 610def _traverse_select(scope): 611 yield from _traverse_ctes(scope) 612 yield from _traverse_tables(scope) 613 yield from _traverse_subqueries(scope) 614 615 616def _traverse_union(scope): 617 prev_scope = None 618 union_scope_stack = [scope] 619 expression_stack = [scope.expression.right, scope.expression.left] 620 621 while expression_stack: 622 expression = expression_stack.pop() 623 union_scope = union_scope_stack[-1] 624 625 new_scope = union_scope.branch( 626 
expression, 627 outer_columns=union_scope.outer_columns, 628 scope_type=ScopeType.UNION, 629 ) 630 631 if isinstance(expression, exp.SetOperation): 632 yield from _traverse_ctes(new_scope) 633 634 union_scope_stack.append(new_scope) 635 expression_stack.extend([expression.right, expression.left]) 636 continue 637 638 for scope in _traverse_scope(new_scope): 639 yield scope 640 641 if prev_scope: 642 union_scope_stack.pop() 643 union_scope.union_scopes = [prev_scope, scope] 644 prev_scope = union_scope 645 646 yield union_scope 647 else: 648 prev_scope = scope 649 650 651def _traverse_ctes(scope): 652 sources = {} 653 654 for cte in scope.ctes: 655 cte_name = cte.alias 656 657 # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. 658 # thus the recursive scope is the first section of the union. 659 with_ = scope.expression.args.get("with") 660 if with_ and with_.recursive: 661 union = cte.this 662 663 if isinstance(union, exp.SetOperation): 664 sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE) 665 666 child_scope = None 667 668 for child_scope in _traverse_scope( 669 scope.branch( 670 cte.this, 671 cte_sources=sources, 672 outer_columns=cte.alias_column_names, 673 scope_type=ScopeType.CTE, 674 ) 675 ): 676 yield child_scope 677 678 # append the final child_scope yielded 679 if child_scope: 680 sources[cte_name] = child_scope 681 scope.cte_scopes.append(child_scope) 682 683 scope.sources.update(sources) 684 scope.cte_sources.update(sources) 685 686 687def _is_derived_table(expression: exp.Subquery) -> bool: 688 """ 689 We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", 690 as it doesn't introduce a new scope. If an alias is present, it shadows all names 691 under the Subquery, so that's one exception to this rule. 
692 """ 693 return isinstance(expression, exp.Subquery) and bool( 694 expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES) 695 ) 696 697 698def _is_from_or_join(expression: exp.Expression) -> bool: 699 """ 700 Determine if `expression` is the FROM or JOIN clause of a SELECT statement. 701 """ 702 parent = expression.parent 703 704 # Subqueries can be arbitrarily nested 705 while isinstance(parent, exp.Subquery): 706 parent = parent.parent 707 708 return isinstance(parent, (exp.From, exp.Join)) 709 710 711def _traverse_tables(scope): 712 sources = {} 713 714 # Traverse FROMs, JOINs, and LATERALs in the order they are defined 715 expressions = [] 716 from_ = scope.expression.args.get("from") 717 if from_: 718 expressions.append(from_.this) 719 720 for join in scope.expression.args.get("joins") or []: 721 expressions.append(join.this) 722 723 if isinstance(scope.expression, exp.Table): 724 expressions.append(scope.expression) 725 726 expressions.extend(scope.expression.args.get("laterals") or []) 727 728 for expression in expressions: 729 if isinstance(expression, exp.Final): 730 expression = expression.this 731 if isinstance(expression, exp.Table): 732 table_name = expression.name 733 source_name = expression.alias_or_name 734 735 if table_name in scope.sources and not expression.db: 736 # This is a reference to a parent source (e.g. a CTE), not an actual table, unless 737 # it is pivoted, because then we get back a new table and hence a new source. 
738 pivots = expression.args.get("pivots") 739 if pivots: 740 sources[pivots[0].alias] = expression 741 else: 742 sources[source_name] = scope.sources[table_name] 743 elif source_name in sources: 744 sources[find_new_name(sources, table_name)] = expression 745 else: 746 sources[source_name] = expression 747 748 # Make sure to not include the joins twice 749 if expression is not scope.expression: 750 expressions.extend(join.this for join in expression.args.get("joins") or []) 751 752 continue 753 754 if not isinstance(expression, exp.DerivedTable): 755 continue 756 757 if isinstance(expression, exp.UDTF): 758 lateral_sources = sources 759 scope_type = ScopeType.UDTF 760 scopes = scope.udtf_scopes 761 elif _is_derived_table(expression): 762 lateral_sources = None 763 scope_type = ScopeType.DERIVED_TABLE 764 scopes = scope.derived_table_scopes 765 expressions.extend(join.this for join in expression.args.get("joins") or []) 766 else: 767 # Makes sure we check for possible sources in nested table constructs 768 expressions.append(expression.this) 769 expressions.extend(join.this for join in expression.args.get("joins") or []) 770 continue 771 772 child_scope = None 773 774 for child_scope in _traverse_scope( 775 scope.branch( 776 expression, 777 lateral_sources=lateral_sources, 778 outer_columns=expression.alias_column_names, 779 scope_type=scope_type, 780 ) 781 ): 782 yield child_scope 783 784 # Tables without aliases will be set as "" 785 # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. 786 # Until then, this means that only a single, unaliased derived table is allowed (rather, 787 # the latest one wins. 
788 sources[expression.alias] = child_scope 789 790 # append the final child_scope yielded 791 if child_scope: 792 scopes.append(child_scope) 793 scope.table_scopes.append(child_scope) 794 795 scope.sources.update(sources) 796 797 798def _traverse_subqueries(scope): 799 for subquery in scope.subqueries: 800 top = None 801 for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): 802 yield child_scope 803 top = child_scope 804 scope.subquery_scopes.append(top) 805 806 807def _traverse_udtfs(scope): 808 if isinstance(scope.expression, exp.Unnest): 809 expressions = scope.expression.expressions 810 elif isinstance(scope.expression, exp.Lateral): 811 expressions = [scope.expression.this] 812 else: 813 expressions = [] 814 815 sources = {} 816 for expression in expressions: 817 if _is_derived_table(expression): 818 top = None 819 for child_scope in _traverse_scope( 820 scope.branch( 821 expression, 822 scope_type=ScopeType.SUBQUERY, 823 outer_columns=expression.alias_column_names, 824 ) 825 ): 826 yield child_scope 827 top = child_scope 828 sources[expression.alias] = child_scope 829 830 scope.subquery_scopes.append(top) 831 832 scope.sources.update(sources) 833 834 835def walk_in_scope(expression, bfs=True, prune=None): 836 """ 837 Returns a generator object which visits all nodes in the syntrax tree, stopping at 838 nodes that start child scopes. 839 840 Args: 841 expression (exp.Expression): 842 bfs (bool): if set to True the BFS traversal order will be applied, 843 otherwise the DFS traversal will be used instead. 844 prune ((node, parent, arg_key) -> bool): callable that returns True if 845 the generator should stop traversing this branch of the tree. 846 847 Yields: 848 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key 849 """ 850 # We'll use this variable to pass state into the dfs generator. 851 # Whenever we set it to True, we exclude a subtree from traversal. 
852 crossed_scope_boundary = False 853 854 for node in expression.walk( 855 bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n)) 856 ): 857 crossed_scope_boundary = False 858 859 yield node 860 861 if node is expression: 862 continue 863 864 if ( 865 isinstance(node, exp.CTE) 866 or ( 867 isinstance(node.parent, (exp.From, exp.Join, exp.Subquery)) 868 and _is_derived_table(node) 869 ) 870 or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query)) 871 or isinstance(node, exp.UNWRAPPED_QUERIES) 872 ): 873 crossed_scope_boundary = True 874 875 if isinstance(node, (exp.Subquery, exp.UDTF)): 876 # The following args are not actually in the inner scope, so we should visit them 877 for key in ("joins", "laterals", "pivots"): 878 for arg in node.args.get(key) or []: 879 yield from walk_in_scope(arg, bfs=bfs) 880 881 882def find_all_in_scope(expression, expression_types, bfs=True): 883 """ 884 Returns a generator object which visits all nodes in this scope and only yields those that 885 match at least one of the specified expression types. 886 887 This does NOT traverse into subscopes. 888 889 Args: 890 expression (exp.Expression): 891 expression_types (tuple[type]|type): the expression type(s) to match. 892 bfs (bool): True to use breadth-first search, False to use depth-first. 893 894 Yields: 895 exp.Expression: nodes 896 """ 897 for expression in walk_in_scope(expression, bfs=bfs): 898 if isinstance(expression, tuple(ensure_collection(expression_types))): 899 yield expression 900 901 902def find_in_scope(expression, expression_types, bfs=True): 903 """ 904 Returns the first node in this scope which matches at least one of the specified types. 905 906 This does NOT traverse into subscopes. 907 908 Args: 909 expression (exp.Expression): 910 expression_types (tuple[type]|type): the expression type(s) to match. 911 bfs (bool): True to use breadth-first search, False to use depth-first. 
912 913 Returns: 914 exp.Expression: the node which matches the criteria or None if no node matching 915 the criteria was found. 916 """ 917 return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Kind of a Scope, describing how it relates to its parent scope."""

    # The outermost scope, created directly for the expression being traversed.
    ROOT = auto()
    # A query nested inside an expression, e.g. WHERE a IN (SELECT ...).
    SUBQUERY = auto()
    # A query used as a table in FROM/JOIN, e.g. FROM (SELECT ...) AS y.
    DERIVED_TABLE = auto()
    # The body of a WITH-clause entry.
    CTE = auto()
    # One branch of a set operation such as UNION.
    UNION = auto()
    # A user-defined table function, e.g. UNNEST or LATERAL.
    UDTF = auto()
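Kind of a scope, describing how it relates to its parent scope (root, subquery, derived table, CTE, union branch, or UDTF).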
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.SetOperation): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources (dict[str, Scope]): Sources from CTEs
        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
        can_be_correlated (bool|None): Set for subquery and UDTF scopes and inherited by every
            scope branched from them. It is consulted by `is_correlated_subquery`, and by
            `columns` when deciding whether a derived table's external columns belong to
            this scope.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_columns=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
        cte_sources=None,
        can_be_correlated=None,
    ):
        self.expression = expression
        # NOTE: a non-empty `sources` / `cte_sources` dict is stored as-is (not copied),
        # so the updates below and later traversal mutate the caller's dict.
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # CTE sources are merged last, so they win over a lateral source with the same name.
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.can_be_correlated = can_be_correlated
        self.clear_cache()

    def clear_cache(self):
        """Drop every lazily computed property so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._table_columns = None
        self._stars = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None
        self._semi_anti_join_tables = None

    def branch(
        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
    ):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            # Inner scopes see all CTEs visible here, plus any passed explicitly.
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            # Correlation is possible inside a subquery/UDTF and in everything nested in one.
            can_be_correlated=self.can_be_correlated
            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
            **kwargs,
        )

    def _collect(self):
        """Walk this scope's own nodes (depth-first) and sort them into the cached buckets."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._table_columns = []
        self._stars = []
        self._join_hints = []
        self._semi_anti_join_tables = set()

        for node in self.walk(bfs=False):
            if node is self.expression:
                continue

            if isinstance(node, exp.Dot) and node.is_star:
                self._stars.append(node)
            elif isinstance(node, exp.Column):
                if isinstance(node.this, exp.Star):
                    self._stars.append(node)
                else:
                    self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                parent = node.parent
                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
                    self._semi_anti_join_tables.add(node.alias_or_name)

                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif _is_derived_table(node) and _is_from_or_join(node):
                self._derived_tables.append(node)
            elif isinstance(node, exp.UNWRAPPED_QUERIES):
                self._subqueries.append(node)
            elif isinstance(node, exp.TableColumn):
                self._table_columns.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True, prune=None):
        """
        Walk the nodes of this scope without descending into child scopes.

        Args:
            bfs: breadth-first traversal when True, depth-first otherwise.
            prune: optional callable invoked with each node; when it returns True the
                branch under that node is not traversed.

        Returns:
            A generator over the nodes of this scope.
        """
        # Forward `prune`; it used to be replaced by None, silently ignoring the caller's callback.
        return walk_in_scope(self.expression, bfs=bfs, prune=prune)

    def find(self, *expression_types, bfs=True):
        """Return the first node of this scope matching any of `expression_types`, or None."""
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """Yield every node of this scope matching any of `expression_types`."""
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.SetOperation]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def stars(self) -> t.List[exp.Column | exp.Dot]:
        """
        List of star expressions (columns or dots) in this scope.
        """
        self._ensure_collected()
        return self._stars

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            # Columns of child scopes that point outside of them may point into this scope.
            external_columns = [
                column
                for scope in itertools.chain(
                    self.subquery_scopes,
                    self.udtf_scopes,
                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
                )
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select,
                    exp.Qualify,
                    exp.Order,
                    exp.Having,
                    exp.Hint,
                    exp.Table,
                    exp.Star,
                    exp.Distinct,
                )
                # Qualified columns are always kept. Unqualified ones are dropped when they sit
                # under QUALIFY/HAVING/hints, function tables, a star's EXCEPT list, or an
                # ORDER BY/DISTINCT (outside windows) where the name matches a projection;
                # presumably such names may refer to select aliases rather than sources.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, (exp.Order, exp.Distinct))
                        and (
                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
                            or column.name not in named_selects
                        )
                    )
                    or (isinstance(ancestor, exp.Star) and column.arg_key != "except")
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def table_columns(self):
        """List of `exp.TableColumn` nodes found in this scope."""
        if self._table_columns is None:
            self._ensure_collected()

        return self._table_columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two references in this scope share the same name.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in self.semi_or_anti_join_tables:
                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
                    # selected source
                    continue

                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        (name, node) pairs for every table, derived table and UDTF referenced in this scope.

        Derived tables and UDTFs are unnested unless they carry pivots, in which case the
        wrapping node is kept so its pivots remain reachable.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.SetOperation):
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                self._external_columns = [
                    c
                    for c in self.columns
                    if c.table not in self.selected_sources
                    and c.table not in self.semi_or_anti_join_tables
                ]

        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect first like the other node-list properties; previously this returned []
        # whenever it was read before something else had triggered collection.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """Pivots attached to any of the nodes in `references`."""
        # `is None` so that an empty result is cached as well.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    @property
    def semi_or_anti_join_tables(self):
        """Alias-or-names of tables appearing as the right side of a SEMI or ANTI join."""
        self._ensure_collected()
        return self._semi_anti_join_tables

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.can_be_correlated and self.external_columns)

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope.

        A falsy `old_name` is looked up as "" (the key used for unaliased derived tables).
        If `old_name` is not a source, `new_name` is mapped to an empty list. Unlike
        `add_source`/`remove_source`, cached properties are not cleared.
        """
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        stack = [self]
        result = []
        while stack:
            scope = stack.pop()
            result.append(scope)
            stack.extend(
                itertools.chain(
                    scope.cte_scopes,
                    scope.union_scopes,
                    scope.table_scopes,
                    scope.subquery_scopes,
                )
            )

        # Reversing the pre-order visit list yields children before their parents.
        yield from reversed(result)

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.SetOperation): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: `SELECT * FROM x` gives `{"x": Table(this="x")}`; `SELECT * FROM x AS y` gives `{"y": Table(this="x")}`; `SELECT * FROM (SELECT ...) AS y` gives `{"y": Scope(...)}`.
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;` the LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_columns`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_columns=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
    can_be_correlated=None,
):
    """
    Create a scope rooted at `expression`.

    Args:
        expression: root node of this scope (e.g. a Select or SetOperation).
        sources: initial mapping of source name to Table/Scope. Lateral and CTE
            sources are merged into it, CTE sources last so they win on a name clash.
            A non-empty dict passed here is used (and mutated) in place, not copied.
        outer_columns: column names the outer query declares for this scope's alias.
        parent: the enclosing Scope, or None.
        scope_type: ScopeType of this scope relative to its parent.
        lateral_sources: mapping of sources introduced by laterals.
        cte_sources: mapping of sources introduced by CTEs.
        can_be_correlated: whether columns here may reference outer scopes.
    """
    self.expression = expression

    self.sources = sources or {}
    self.lateral_sources = lateral_sources or {}
    self.cte_sources = cte_sources or {}
    for extra_sources in (self.lateral_sources, self.cte_sources):
        self.sources.update(extra_sources)

    self.outer_columns = outer_columns or []
    self.parent = parent
    self.scope_type = scope_type
    self.can_be_correlated = can_be_correlated

    # Child scope lists of each kind; they start out empty here.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    self.clear_cache()
def clear_cache(self):
    """
    Drop every lazily computed value so it is rebuilt on the next access.

    `_collected` is reset to False and every other cache attribute to None.
    """
    self._collected = False
    for cache_attr in (
        "_raw_columns",
        "_table_columns",
        "_stars",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
        "_semi_anti_join_tables",
    ):
        setattr(self, cache_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Create a child scope whose parent is this scope.

    Args:
        expression: root node of the child; it is unnested before use.
        scope_type: ScopeType of the child.
        sources: extra sources for the child (a copy is passed on).
        cte_sources: CTE sources layered on top of this scope's own CTE sources.
        lateral_sources: lateral sources for the child (a copy is passed on).
        **kwargs: forwarded unchanged to the Scope constructor.

    Returns:
        Scope: the new child scope.
    """
    # Argument CTEs override inherited ones with the same name.
    merged_ctes = dict(self.cte_sources)
    merged_ctes.update(cte_sources or {})

    # Correlation is inherited, and subqueries/UDTFs may always be correlated.
    correlatable = self.can_be_correlated or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF)

    return Scope(
        expression=expression.unnest(),
        sources=sources.copy() if sources else None,
        parent=self,
        scope_type=scope_type,
        cte_sources=merged_ctes,
        lateral_sources=lateral_sources.copy() if lateral_sources else None,
        can_be_correlated=correlatable,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Put `new` where `old` sits in the syntax tree, then invalidate this scope's caches.

    Prefer this over calling `exp.Expression.replace` directly so the `Scope`
    does not keep serving results computed from the old tree.

    Args:
        old (exp.Expression): node currently in the tree
        new (exp.Expression): node to put in its place
    """
    target = old
    target.replace(new)
    # The tree changed underneath the scope; every cached view of it is stale.
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Table expressions found in this scope.

    Calls `_ensure_collected()` before reading the cached list.

    Returns:
        list[exp.Table]: tables
    """
    self._ensure_collected()
    found = self._tables
    return found
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    CTE expressions found in this scope.

    Calls `_ensure_collected()` before reading the cached list.

    Returns:
        list[exp.CTE]: ctes
    """
    self._ensure_collected()
    found = self._ctes
    return found
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables found in this scope.

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    Calls `_ensure_collected()` before reading the cached list.

    Returns:
        list[exp.Subquery]: derived tables
    """
    self._ensure_collected()
    found = self._derived_tables
    return found
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Functions" found in this scope.

    Calls `_ensure_collected()` before reading the cached list.

    Returns:
        list[exp.UDTF]: UDTFs
    """
    self._ensure_collected()
    found = self._udtfs
    return found
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries found in this scope.

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    Calls `_ensure_collected()` before reading the cached list.

    Returns:
        list[exp.Select | exp.SetOperation]: subqueries
    """
    self._ensure_collected()
    found = self._subqueries
    return found
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.SetOperation]: subqueries
@property
def stars(self) -> t.List[exp.Column | exp.Dot]:
    """
    Star expressions (Column or Dot nodes) found in this scope.

    Calls `_ensure_collected()` before reading the cached list.
    """
    self._ensure_collected()
    found = self._stars
    return found
List of star expressions (columns or dots) in this scope.
@property
def columns(self):
    """
    List of columns in this scope.

    The result is computed once and cached in `_columns` until `clear_cache()`.
    It combines this scope's raw columns with the external columns of its
    subquery scopes, UDTF scopes and correlatable derived-table scopes, then
    filters them by their nearest enclosing clause (see the comment on the
    filter below).

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is None:
        self._ensure_collected()
        columns = self._raw_columns

        # Child-scope columns whose table is not one of the child's own selected
        # sources; they may point at sources of this scope.
        external_columns = [
            column
            for scope in itertools.chain(
                self.subquery_scopes,
                self.udtf_scopes,
                (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
            )
            for column in scope.external_columns
        ]

        # Output names of this scope's projections; ORDER BY / DISTINCT names that
        # match one of these are not kept as column references.
        named_selects = set(self.expression.named_selects)

        self._columns = []
        for column in columns + external_columns:
            # The nearest enclosing node among these types decides whether the
            # column is kept.
            ancestor = column.find_ancestor(
                exp.Select,
                exp.Qualify,
                exp.Order,
                exp.Having,
                exp.Hint,
                exp.Table,
                exp.Star,
                exp.Distinct,
            )
            # Keep the column if any of these hold:
            #   - there is no such ancestor, or the column is table-qualified;
            #   - the nearest ancestor is a Select;
            #   - it is a Table whose `this` is not a function;
            #   - it is an Order/Distinct that belongs to a Window/WithinGroup, or the
            #     column name is not one of this scope's named selects;
            #   - it is a Star and the column is not in its "except" arg.
            # Everything else (e.g. unqualified names under Qualify/Having/Hint)
            # is dropped.
            if (
                not ancestor
                or column.table
                or isinstance(ancestor, exp.Select)
                or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                or (
                    isinstance(ancestor, (exp.Order, exp.Distinct))
                    and (
                        isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
                        or column.name not in named_selects
                    )
                )
                or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
            ):
                self._columns.append(column)

    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Mapping of nodes and sources that are actually selected from in this scope.

    That is, all tables in a schema are selectable at any point. But a
    table only becomes a selected source if it's included in a FROM or JOIN clause.
    The result is cached until `clear_cache()`.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

    Raises:
        OptimizeError: if a reference name repeats one already added to the result.
    """
    if self._selected_sources is None:
        result = {}

        # Evaluate `references` first: it goes through `self.tables`, which calls
        # `_ensure_collected()`. Only then read the SEMI/ANTI join set, through the
        # public accessor so an unpopulated (None) set counts as empty instead of
        # making the membership test below raise TypeError.
        references = self.references
        semi_anti_tables = self.semi_or_anti_join_tables

        for name, node in references:
            if name in semi_anti_tables:
                # The RHS table of SEMI/ANTI joins shouldn't be collected as a
                # selected source
                continue

            if name in result:
                raise OptimizeError(f"Alias already used: {name}")
            if name in self.sources:
                result[name] = (node, self.sources[name])

        self._selected_sources = result
    return self._selected_sources
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    Every (name, node) pair this scope selects from, cached after the first call.

    Tables come first, keyed by `alias_or_name`. Derived tables and then UDTFs
    follow, keyed by `alias`; such a node is unnested unless it has "pivots",
    in which case the node itself is kept.
    """
    if self._references is None:
        refs = []
        self._references = refs
        refs.extend((tbl.alias_or_name, tbl) for tbl in self.tables)
        refs.extend(
            (node.alias, node if node.args.get("pivots") else node.unnest())
            for node in itertools.chain(self.derived_tables, self.udtfs)
        )

    return self._references
@property
def external_columns(self):
    """
    Columns that appear to reference sources in outer scopes.

    For a set operation this is the left branch's external columns followed by
    the right branch's. Otherwise it is every column of this scope whose table is
    neither a selected source nor a SEMI/ANTI join table of this scope. Cached
    until `clear_cache()`.

    Returns:
        list[exp.Column]: Column instances that don't reference
            sources in the current scope.
    """
    if self._external_columns is None:
        if isinstance(self.expression, exp.SetOperation):
            left, right = self.union_scopes
            found = left.external_columns + right.external_columns
        else:
            found = []
            for column in self.columns:
                if column.table in self.selected_sources:
                    continue
                if column.table in self.semi_or_anti_join_tables:
                    continue
                found.append(column)
        self._external_columns = found

    return self._external_columns
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns of this scope that carry no table qualifier.

    Returns:
        list[exp.Column]: the members of `columns` whose `table` is falsy
    """
    return list(filter(lambda col: not col.table, self.columns))
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Join hints in this scope that reference tables.

    Returns:
        list[exp.JoinHint]: the stored hints, or a fresh empty list when none are stored
    """
    hints = self._join_hints
    return [] if hints is None else hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Collect the columns of this scope that are qualified with `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances whose `table` equals `source_name`
    """
    return list(filter(lambda col: col.table == source_name, self.columns))
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self):
    """True when this scope's type is ScopeType.SUBQUERY."""
    kind = self.scope_type
    return kind == ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """True when this scope's type is ScopeType.DERIVED_TABLE."""
    kind = self.scope_type
    return kind == ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """True when this scope's type is ScopeType.UNION."""
    kind = self.scope_type
    return kind == ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """True when this scope's type is ScopeType.CTE (a common table expression)."""
    kind = self.scope_type
    return kind == ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """True when this scope's type is ScopeType.ROOT."""
    kind = self.scope_type
    return kind == ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """True when this scope's type is ScopeType.UDTF (User Defined Table Function)."""
    kind = self.scope_type
    return kind == ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Re-register the source stored under `old_name` under `new_name`.

    A falsy `old_name` is looked up as "". If nothing is stored under that key,
    `new_name` is mapped to an empty list. Unlike `add_source`/`remove_source`,
    this does not call `clear_cache()`, so previously cached values are kept.
    """
    lookup_key = old_name or ""
    self.sources[new_name] = self.sources.pop(lookup_key, [])
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name` (overwriting any existing entry) and reset the caches."""
    registry = self.sources
    registry[name] = source
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Unregister the source stored under `name`, if any, and reset the caches."""
    registry = self.sources
    if name in registry:
        del registry[name]
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    The whole tree is gathered before the first item is yielded. Each scope is
    yielded after all of its descendants (so this scope comes last), and child
    lists are visited in the order cte_scopes, union_scopes, table_scopes,
    subquery_scopes.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    pending = [self]
    visited = []
    while pending:
        current = pending.pop()
        visited.append(current)
        for children in (
            current.cte_scopes,
            current.union_scopes,
            current.table_scopes,
            current.subquery_scopes,
        ):
            pending.extend(children)

    yield from visited[::-1]
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count how often each source is selected across the scope tree.

    Every scope yielded by `traverse()` contributes one count per entry of its
    `selected_sources`, keyed by the `id()` of that entry's source.

    Returns:
        dict[int, int]: defaultdict mapping source id to reference count
            (missing ids read as 0)
    """
    counts = defaultdict(int)
    source_ids = (
        id(source)
        for scope in self.traverse()
        for _node, source in scope.selected_sources.values()
    )
    for source_id in source_ids:
        counts[source_id] += 1
    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances (empty if `expression` is not one
        of the TRAVERSABLES types)
    """
    if not isinstance(expression, TRAVERSABLES):
        return []
    root = Scope(expression)
    return list(_traverse_scope(root))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree and return its root.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope: the last entry of `traverse_scope(expression)`, which (as
        in that function's doctest) is the outermost scope. Presumably None when
        no scopes are produced (via `seq_get` on an empty list).
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    A node that starts a child scope (a CTE, a derived table under From/Join/Subquery,
    a Query directly under a UDTF, or an `exp.UNWRAPPED_QUERIES` instance) is itself
    yielded, but its subtree is excluded. For such a node that is a Subquery or UDTF,
    its "joins", "laterals" and "pivots" args are still walked recursively.

    Args:
        expression (exp.Expression): the node to start walking from.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: nodes in the current scope
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    # The prune lambda reads `crossed_scope_boundary` lazily. This presumably relies on
    # `walk()` calling prune on a node only after resuming from yielding it, i.e.
    # after the loop body below has run for that node. TODO confirm against Expression.walk.
    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        # Reset for the node about to be processed; the walker has already consulted
        # the flag for the previously yielded node.
        crossed_scope_boundary = False

        yield node

        # The starting node never counts as a child-scope boundary.
        if node is expression:
            continue

        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                # NOTE(review): `prune` is not forwarded to this recursive call, so the caller's
                # prune callback is not applied inside these args. Confirm whether that is intended.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the node to start walking from.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:
exp.Expression: nodes in the current scope
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the node whose scope is searched.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the requested types once, instead of rebuilding the tuple for
    # every visited node.
    wanted_types = tuple(ensure_collection(expression_types))
    # A distinct loop name avoids shadowing the `expression` parameter.
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted_types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the node whose scope is searched.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the node whose scope is searched.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None if no node matches.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the node whose scope is searched.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.