sqlglot.optimizer.scope
1from __future__ import annotations 2 3import itertools 4import logging 5import typing as t 6from collections import defaultdict 7from enum import Enum, auto 8 9from sqlglot import exp 10from sqlglot.errors import OptimizeError 11from sqlglot.helper import ensure_collection, find_new_name, seq_get 12 13logger = logging.getLogger("sqlglot") 14 15TRAVERSABLES = (exp.Query, exp.DDL, exp.DML) 16 17 18class ScopeType(Enum): 19 ROOT = auto() 20 SUBQUERY = auto() 21 DERIVED_TABLE = auto() 22 CTE = auto() 23 UNION = auto() 24 UDTF = auto() 25 26 27class Scope: 28 """ 29 Selection scope. 30 31 Attributes: 32 expression (exp.Select|exp.SetOperation): Root expression of this scope 33 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 34 a Table expression or another Scope instance. For example: 35 SELECT * FROM x {"x": Table(this="x")} 36 SELECT * FROM x AS y {"y": Table(this="x")} 37 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 38 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 39 For example: 40 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 41 The LATERAL VIEW EXPLODE gets x as a source. 42 cte_sources (dict[str, Scope]): Sources from CTES 43 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 44 defines a column list for the alias of this scope, this is that list of columns. 45 For example: 46 SELECT * FROM (SELECT ...) 
AS y(col1, col2) 47 The inner query would have `["col1", "col2"]` for its `outer_columns` 48 parent (Scope): Parent scope 49 scope_type (ScopeType): Type of this scope, relative to it's parent 50 subquery_scopes (list[Scope]): List of all child scopes for subqueries 51 cte_scopes (list[Scope]): List of all child scopes for CTEs 52 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 53 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 54 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 55 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 56 a list of the left and right child scopes. 57 """ 58 59 def __init__( 60 self, 61 expression, 62 sources=None, 63 outer_columns=None, 64 parent=None, 65 scope_type=ScopeType.ROOT, 66 lateral_sources=None, 67 cte_sources=None, 68 can_be_correlated=None, 69 ): 70 self.expression = expression 71 self.sources = sources or {} 72 self.lateral_sources = lateral_sources or {} 73 self.cte_sources = cte_sources or {} 74 self.sources.update(self.lateral_sources) 75 self.sources.update(self.cte_sources) 76 self.outer_columns = outer_columns or [] 77 self.parent = parent 78 self.scope_type = scope_type 79 self.subquery_scopes = [] 80 self.derived_table_scopes = [] 81 self.table_scopes = [] 82 self.cte_scopes = [] 83 self.union_scopes = [] 84 self.udtf_scopes = [] 85 self.can_be_correlated = can_be_correlated 86 self.clear_cache() 87 88 def clear_cache(self): 89 self._collected = False 90 self._raw_columns = None 91 self._table_columns = None 92 self._stars = None 93 self._derived_tables = None 94 self._udtfs = None 95 self._tables = None 96 self._ctes = None 97 self._subqueries = None 98 self._selected_sources = None 99 self._columns = None 100 self._external_columns = None 101 self._local_columns = None 102 self._join_hints = None 103 self._pivots = None 104 self._references = None 105 
self._semi_anti_join_tables = None 106 self._column_index = None 107 108 def branch( 109 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 110 ): 111 """Branch from the current scope to a new, inner scope""" 112 return Scope( 113 expression=expression.unnest(), 114 sources=sources.copy() if sources else None, 115 parent=self, 116 scope_type=scope_type, 117 cte_sources={**self.cte_sources, **(cte_sources or {})}, 118 lateral_sources=lateral_sources.copy() if lateral_sources else None, 119 can_be_correlated=self.can_be_correlated 120 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 121 **kwargs, 122 ) 123 124 def _collect(self): 125 self._tables = [] 126 self._ctes = [] 127 self._subqueries = [] 128 self._derived_tables = [] 129 self._udtfs = [] 130 self._raw_columns = [] 131 self._table_columns = [] 132 self._stars = [] 133 self._join_hints = [] 134 self._semi_anti_join_tables = set() 135 self._column_index = set() 136 137 for node in self.walk(bfs=False): 138 if node is self.expression: 139 continue 140 141 if isinstance(node, exp.Dot) and node.is_star: 142 self._stars.append(node) 143 elif isinstance(node, exp.Column) and not isinstance(node, exp.Pseudocolumn): 144 self._column_index.add(id(node)) 145 146 if isinstance(node.this, exp.Star): 147 self._stars.append(node) 148 else: 149 self._raw_columns.append(node) 150 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 151 parent = node.parent 152 if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join: 153 self._semi_anti_join_tables.add(node.alias_or_name) 154 155 self._tables.append(node) 156 elif isinstance(node, exp.JoinHint): 157 self._join_hints.append(node) 158 elif isinstance(node, exp.UDTF): 159 self._udtfs.append(node) 160 elif isinstance(node, exp.CTE): 161 self._ctes.append(node) 162 elif _is_derived_table(node) and _is_from_or_join(node): 163 self._derived_tables.append(node) 164 elif isinstance(node, 
exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node): 165 self._subqueries.append(node) 166 elif isinstance(node, exp.TableColumn): 167 self._table_columns.append(node) 168 169 self._collected = True 170 171 def _ensure_collected(self): 172 if not self._collected: 173 self._collect() 174 175 def walk(self, bfs=True, prune=None): 176 return walk_in_scope(self.expression, bfs=bfs, prune=None) 177 178 def find(self, *expression_types, bfs=True): 179 return find_in_scope(self.expression, expression_types, bfs=bfs) 180 181 def find_all(self, *expression_types, bfs=True): 182 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 183 184 def replace(self, old, new): 185 """ 186 Replace `old` with `new`. 187 188 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 189 190 Args: 191 old (exp.Expression): old node 192 new (exp.Expression): new node 193 """ 194 old.replace(new) 195 self.clear_cache() 196 197 @property 198 def tables(self): 199 """ 200 List of tables in this scope. 201 202 Returns: 203 list[exp.Table]: tables 204 """ 205 self._ensure_collected() 206 return self._tables 207 208 @property 209 def ctes(self): 210 """ 211 List of CTEs in this scope. 212 213 Returns: 214 list[exp.CTE]: ctes 215 """ 216 self._ensure_collected() 217 return self._ctes 218 219 @property 220 def derived_tables(self): 221 """ 222 List of derived tables in this scope. 223 224 For example: 225 SELECT * FROM (SELECT ...) <- that's a derived table 226 227 Returns: 228 list[exp.Subquery]: derived tables 229 """ 230 self._ensure_collected() 231 return self._derived_tables 232 233 @property 234 def udtfs(self): 235 """ 236 List of "User Defined Tabular Functions" in this scope. 237 238 Returns: 239 list[exp.UDTF]: UDTFs 240 """ 241 self._ensure_collected() 242 return self._udtfs 243 244 @property 245 def subqueries(self): 246 """ 247 List of subqueries in this scope. 248 249 For example: 250 SELECT * FROM x WHERE a IN (SELECT ...) 
<- that's a subquery 251 252 Returns: 253 list[exp.Select | exp.SetOperation]: subqueries 254 """ 255 self._ensure_collected() 256 return self._subqueries 257 258 @property 259 def stars(self) -> t.List[exp.Column | exp.Dot]: 260 """ 261 List of star expressions (columns or dots) in this scope. 262 """ 263 self._ensure_collected() 264 return self._stars 265 266 @property 267 def column_index(self) -> t.Set[int]: 268 """ 269 Set of column object IDs that belong to this scope's expression. 270 """ 271 self._ensure_collected() 272 return self._column_index 273 274 @property 275 def columns(self): 276 """ 277 List of columns in this scope. 278 279 Returns: 280 list[exp.Column]: Column instances in this scope, plus any 281 Columns that reference this scope from correlated subqueries. 282 """ 283 if self._columns is None: 284 self._ensure_collected() 285 columns = self._raw_columns 286 287 external_columns = [ 288 column 289 for scope in itertools.chain( 290 self.subquery_scopes, 291 self.udtf_scopes, 292 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 293 ) 294 for column in scope.external_columns 295 ] 296 297 named_selects = set(self.expression.named_selects) 298 299 self._columns = [] 300 for column in columns + external_columns: 301 ancestor = column.find_ancestor( 302 exp.Select, 303 exp.Qualify, 304 exp.Order, 305 exp.Having, 306 exp.Hint, 307 exp.Table, 308 exp.Star, 309 exp.Distinct, 310 ) 311 if ( 312 not ancestor 313 or column.table 314 or isinstance(ancestor, exp.Select) 315 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 316 or ( 317 isinstance(ancestor, (exp.Order, exp.Distinct)) 318 and ( 319 isinstance(ancestor.parent, (exp.Window, exp.WithinGroup)) 320 or not isinstance(ancestor.parent, exp.Select) 321 or column.name not in named_selects 322 ) 323 ) 324 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_") 325 ): 326 self._columns.append(column) 327 328 return self._columns 329 330 
@property 331 def table_columns(self): 332 if self._table_columns is None: 333 self._ensure_collected() 334 335 return self._table_columns 336 337 @property 338 def selected_sources(self): 339 """ 340 Mapping of nodes and sources that are actually selected from in this scope. 341 342 That is, all tables in a schema are selectable at any point. But a 343 table only becomes a selected source if it's included in a FROM or JOIN clause. 344 345 Returns: 346 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 347 """ 348 if self._selected_sources is None: 349 result = {} 350 351 for name, node in self.references: 352 if name in self._semi_anti_join_tables: 353 # The RHS table of SEMI/ANTI joins shouldn't be collected as a 354 # selected source 355 continue 356 357 if name in result: 358 raise OptimizeError(f"Alias already used: {name}") 359 if name in self.sources: 360 result[name] = (node, self.sources[name]) 361 362 self._selected_sources = result 363 return self._selected_sources 364 365 @property 366 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 367 if self._references is None: 368 self._references = [] 369 370 for table in self.tables: 371 self._references.append((table.alias_or_name, table)) 372 for expression in itertools.chain(self.derived_tables, self.udtfs): 373 self._references.append( 374 ( 375 _get_source_alias(expression), 376 expression if expression.args.get("pivots") else expression.unnest(), 377 ) 378 ) 379 380 return self._references 381 382 @property 383 def external_columns(self): 384 """ 385 Columns that appear to reference sources in outer scopes. 386 387 Returns: 388 list[exp.Column]: Column instances that don't reference sources in the current scope. 
389 """ 390 if self._external_columns is None: 391 if isinstance(self.expression, exp.SetOperation): 392 left, right = self.union_scopes 393 self._external_columns = left.external_columns + right.external_columns 394 else: 395 self._external_columns = [ 396 c 397 for c in self.columns 398 if c.table not in self.sources and c.table not in self.semi_or_anti_join_tables 399 ] 400 401 return self._external_columns 402 403 @property 404 def local_columns(self): 405 """ 406 Columns in this scope that are not external. 407 408 Returns: 409 list[exp.Column]: Column instances that reference sources in the current scope. 410 """ 411 if self._local_columns is None: 412 external_columns = set(self.external_columns) 413 self._local_columns = [c for c in self.columns if c not in external_columns] 414 415 return self._local_columns 416 417 @property 418 def unqualified_columns(self): 419 """ 420 Unqualified columns in the current scope. 421 422 Returns: 423 list[exp.Column]: Unqualified columns 424 """ 425 return [c for c in self.columns if not c.table] 426 427 @property 428 def join_hints(self): 429 """ 430 Hints that exist in the scope that reference tables 431 432 Returns: 433 list[exp.JoinHint]: Join hints that are referenced within the scope 434 """ 435 if self._join_hints is None: 436 return [] 437 return self._join_hints 438 439 @property 440 def pivots(self): 441 if not self._pivots: 442 self._pivots = [ 443 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 444 ] 445 446 return self._pivots 447 448 @property 449 def semi_or_anti_join_tables(self): 450 return self._semi_anti_join_tables or set() 451 452 def source_columns(self, source_name): 453 """ 454 Get all columns in the current scope for a particular source. 
455 456 Args: 457 source_name (str): Name of the source 458 Returns: 459 list[exp.Column]: Column instances that reference `source_name` 460 """ 461 return [column for column in self.columns if column.table == source_name] 462 463 @property 464 def is_subquery(self): 465 """Determine if this scope is a subquery""" 466 return self.scope_type == ScopeType.SUBQUERY 467 468 @property 469 def is_derived_table(self): 470 """Determine if this scope is a derived table""" 471 return self.scope_type == ScopeType.DERIVED_TABLE 472 473 @property 474 def is_union(self): 475 """Determine if this scope is a union""" 476 return self.scope_type == ScopeType.UNION 477 478 @property 479 def is_cte(self): 480 """Determine if this scope is a common table expression""" 481 return self.scope_type == ScopeType.CTE 482 483 @property 484 def is_root(self): 485 """Determine if this is the root scope""" 486 return self.scope_type == ScopeType.ROOT 487 488 @property 489 def is_udtf(self): 490 """Determine if this scope is a UDTF (User Defined Table Function)""" 491 return self.scope_type == ScopeType.UDTF 492 493 @property 494 def is_correlated_subquery(self): 495 """Determine if this scope is a correlated subquery""" 496 return bool(self.can_be_correlated and self.external_columns) 497 498 def rename_source(self, old_name, new_name): 499 """Rename a source in this scope""" 500 old_name = old_name or "" 501 if old_name in self.sources: 502 self.sources[new_name] = self.sources.pop(old_name) 503 504 def add_source(self, name, source): 505 """Add a source to this scope""" 506 self.sources[name] = source 507 self.clear_cache() 508 509 def remove_source(self, name): 510 """Remove a source from this scope""" 511 self.sources.pop(name, None) 512 self.clear_cache() 513 514 def __repr__(self): 515 return f"Scope<{self.expression.sql()}>" 516 517 def traverse(self): 518 """ 519 Traverse the scope tree from this node. 
520 521 Yields: 522 Scope: scope instances in depth-first-search post-order 523 """ 524 stack = [self] 525 result = [] 526 while stack: 527 scope = stack.pop() 528 result.append(scope) 529 stack.extend( 530 itertools.chain( 531 scope.cte_scopes, 532 scope.union_scopes, 533 scope.table_scopes, 534 scope.subquery_scopes, 535 ) 536 ) 537 538 yield from reversed(result) 539 540 def ref_count(self): 541 """ 542 Count the number of times each scope in this tree is referenced. 543 544 Returns: 545 dict[int, int]: Mapping of Scope instance ID to reference count 546 """ 547 scope_ref_count = defaultdict(lambda: 0) 548 549 for scope in self.traverse(): 550 for _, source in scope.selected_sources.values(): 551 scope_ref_count[id(source)] += 1 552 553 for name in scope._semi_anti_join_tables: 554 # semi/anti join sources are not actually selected but we still need to 555 # increment their ref count to avoid them being optimized away 556 if name in scope.sources: 557 scope_ref_count[id(scope.sources[name])] += 1 558 559 return scope_ref_count 560 561 562def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 563 """ 564 Traverse an expression by its "scopes". 565 566 "Scope" represents the current context of a Select statement. 567 568 This is helpful for optimizing queries, where we need more information than 569 the expression tree itself. For example, we might care about the source 570 names within a subquery. Returns a list because a generator could result in 571 incomplete properties which is confusing. 
572 573 Examples: 574 >>> import sqlglot 575 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 576 >>> scopes = traverse_scope(expression) 577 >>> scopes[0].expression.sql(), list(scopes[0].sources) 578 ('SELECT a FROM x', ['x']) 579 >>> scopes[1].expression.sql(), list(scopes[1].sources) 580 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 581 582 Args: 583 expression: Expression to traverse 584 585 Returns: 586 A list of the created scope instances 587 """ 588 if isinstance(expression, TRAVERSABLES): 589 return list(_traverse_scope(Scope(expression))) 590 return [] 591 592 593def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 594 """ 595 Build a scope tree. 596 597 Args: 598 expression: Expression to build the scope tree for. 599 600 Returns: 601 The root scope 602 """ 603 return seq_get(traverse_scope(expression), -1) 604 605 606def _traverse_scope(scope): 607 expression = scope.expression 608 609 if isinstance(expression, exp.Select): 610 yield from _traverse_select(scope) 611 elif isinstance(expression, exp.SetOperation): 612 yield from _traverse_ctes(scope) 613 yield from _traverse_union(scope) 614 return 615 elif isinstance(expression, exp.Subquery): 616 if scope.is_root: 617 yield from _traverse_select(scope) 618 else: 619 yield from _traverse_subqueries(scope) 620 elif isinstance(expression, exp.Table): 621 yield from _traverse_tables(scope) 622 elif isinstance(expression, exp.UDTF): 623 yield from _traverse_udtfs(scope) 624 elif isinstance(expression, exp.DDL): 625 if isinstance(expression.expression, exp.Query): 626 yield from _traverse_ctes(scope) 627 yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources)) 628 return 629 elif isinstance(expression, exp.DML): 630 yield from _traverse_ctes(scope) 631 for query in find_all_in_scope(expression, exp.Query): 632 # This check ensures we don't yield the CTE/nested queries twice 633 if not isinstance(query.parent, (exp.CTE, exp.Subquery)): 
634 yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources)) 635 return 636 else: 637 logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression)) 638 return 639 640 yield scope 641 642 643def _traverse_select(scope): 644 yield from _traverse_ctes(scope) 645 yield from _traverse_tables(scope) 646 yield from _traverse_subqueries(scope) 647 648 649def _traverse_union(scope): 650 prev_scope = None 651 union_scope_stack = [scope] 652 expression_stack = [scope.expression.right, scope.expression.left] 653 654 while expression_stack: 655 expression = expression_stack.pop() 656 union_scope = union_scope_stack[-1] 657 658 new_scope = union_scope.branch( 659 expression, 660 outer_columns=union_scope.outer_columns, 661 scope_type=ScopeType.UNION, 662 ) 663 664 if isinstance(expression, exp.SetOperation): 665 yield from _traverse_ctes(new_scope) 666 667 union_scope_stack.append(new_scope) 668 expression_stack.extend([expression.right, expression.left]) 669 continue 670 671 for scope in _traverse_scope(new_scope): 672 yield scope 673 674 if prev_scope: 675 union_scope_stack.pop() 676 union_scope.union_scopes = [prev_scope, scope] 677 prev_scope = union_scope 678 679 yield union_scope 680 else: 681 prev_scope = scope 682 683 684def _traverse_ctes(scope): 685 sources = {} 686 687 for cte in scope.ctes: 688 cte_name = cte.alias 689 690 # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. 691 # thus the recursive scope is the first section of the union. 
692 with_ = scope.expression.args.get("with_") 693 if with_ and with_.recursive: 694 union = cte.this 695 696 if isinstance(union, exp.SetOperation): 697 sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE) 698 699 child_scope = None 700 701 for child_scope in _traverse_scope( 702 scope.branch( 703 cte.this, 704 cte_sources=sources, 705 outer_columns=cte.alias_column_names, 706 scope_type=ScopeType.CTE, 707 ) 708 ): 709 yield child_scope 710 711 # append the final child_scope yielded 712 if child_scope: 713 sources[cte_name] = child_scope 714 scope.cte_scopes.append(child_scope) 715 716 scope.sources.update(sources) 717 scope.cte_sources.update(sources) 718 719 720def _is_derived_table(expression: exp.Subquery) -> bool: 721 """ 722 We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", 723 as it doesn't introduce a new scope. If an alias is present, it shadows all names 724 under the Subquery, so that's one exception to this rule. 725 """ 726 return isinstance(expression, exp.Subquery) and bool( 727 expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES) 728 ) 729 730 731def _is_from_or_join(expression: exp.Expression) -> bool: 732 """ 733 Determine if `expression` is the FROM or JOIN clause of a SELECT statement. 
734 """ 735 parent = expression.parent 736 737 # Subqueries can be arbitrarily nested 738 while isinstance(parent, exp.Subquery): 739 parent = parent.parent 740 741 return isinstance(parent, (exp.From, exp.Join)) 742 743 744def _traverse_tables(scope): 745 sources = {} 746 747 # Traverse FROMs, JOINs, and LATERALs in the order they are defined 748 expressions = [] 749 from_ = scope.expression.args.get("from_") 750 if from_: 751 expressions.append(from_.this) 752 753 for join in scope.expression.args.get("joins") or []: 754 expressions.append(join.this) 755 756 if isinstance(scope.expression, exp.Table): 757 expressions.append(scope.expression) 758 759 expressions.extend(scope.expression.args.get("laterals") or []) 760 761 for expression in expressions: 762 if isinstance(expression, exp.Final): 763 expression = expression.this 764 if isinstance(expression, exp.Table): 765 table_name = expression.name 766 source_name = expression.alias_or_name 767 768 if table_name in scope.sources and not expression.db: 769 # This is a reference to a parent source (e.g. a CTE), not an actual table, unless 770 # it is pivoted, because then we get back a new table and hence a new source. 
771 pivots = expression.args.get("pivots") 772 if pivots: 773 sources[pivots[0].alias] = expression 774 else: 775 sources[source_name] = scope.sources[table_name] 776 elif source_name in sources: 777 sources[find_new_name(sources, table_name)] = expression 778 else: 779 sources[source_name] = expression 780 781 # Make sure to not include the joins twice 782 if expression is not scope.expression: 783 expressions.extend(join.this for join in expression.args.get("joins") or []) 784 785 continue 786 787 if not isinstance(expression, exp.DerivedTable): 788 continue 789 790 if isinstance(expression, exp.UDTF): 791 lateral_sources = sources 792 scope_type = ScopeType.UDTF 793 scopes = scope.udtf_scopes 794 elif _is_derived_table(expression): 795 lateral_sources = None 796 scope_type = ScopeType.DERIVED_TABLE 797 scopes = scope.derived_table_scopes 798 expressions.extend(join.this for join in expression.args.get("joins") or []) 799 else: 800 # Makes sure we check for possible sources in nested table constructs 801 expressions.append(expression.this) 802 expressions.extend(join.this for join in expression.args.get("joins") or []) 803 continue 804 805 child_scope = None 806 807 for child_scope in _traverse_scope( 808 scope.branch( 809 expression, 810 lateral_sources=lateral_sources, 811 outer_columns=expression.alias_column_names, 812 scope_type=scope_type, 813 ) 814 ): 815 yield child_scope 816 817 # Tables without aliases will be set as "" 818 # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. 819 # Until then, this means that only a single, unaliased derived table is allowed (rather, 820 # the latest one wins. 
821 sources[_get_source_alias(expression)] = child_scope 822 823 # append the final child_scope yielded 824 if child_scope: 825 scopes.append(child_scope) 826 scope.table_scopes.append(child_scope) 827 828 scope.sources.update(sources) 829 830 831def _traverse_subqueries(scope): 832 for subquery in scope.subqueries: 833 top = None 834 for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): 835 yield child_scope 836 top = child_scope 837 scope.subquery_scopes.append(top) 838 839 840def _traverse_udtfs(scope): 841 if isinstance(scope.expression, exp.Unnest): 842 expressions = scope.expression.expressions 843 elif isinstance(scope.expression, exp.Lateral): 844 expressions = [scope.expression.this] 845 else: 846 expressions = [] 847 848 sources = {} 849 for expression in expressions: 850 if isinstance(expression, exp.Subquery): 851 top = None 852 for child_scope in _traverse_scope( 853 scope.branch( 854 expression, 855 scope_type=ScopeType.SUBQUERY, 856 outer_columns=expression.alias_column_names, 857 ) 858 ): 859 yield child_scope 860 top = child_scope 861 sources[_get_source_alias(expression)] = child_scope 862 863 scope.subquery_scopes.append(top) 864 865 scope.sources.update(sources) 866 867 868def walk_in_scope(expression, bfs=True, prune=None): 869 """ 870 Returns a generator object which visits all nodes in the syntrax tree, stopping at 871 nodes that start child scopes. 872 873 Args: 874 expression (exp.Expression): 875 bfs (bool): if set to True the BFS traversal order will be applied, 876 otherwise the DFS traversal will be used instead. 877 prune ((node, parent, arg_key) -> bool): callable that returns True if 878 the generator should stop traversing this branch of the tree. 879 880 Yields: 881 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key 882 """ 883 # We'll use this variable to pass state into the dfs generator. 884 # Whenever we set it to True, we exclude a subtree from traversal. 
885 crossed_scope_boundary = False 886 887 for node in expression.walk( 888 bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n)) 889 ): 890 crossed_scope_boundary = False 891 892 yield node 893 894 if node is expression: 895 continue 896 897 if ( 898 isinstance(node, exp.CTE) 899 or (isinstance(node.parent, (exp.From, exp.Join)) and _is_derived_table(node)) 900 or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query)) 901 or isinstance(node, exp.UNWRAPPED_QUERIES) 902 ): 903 crossed_scope_boundary = True 904 905 if isinstance(node, (exp.Subquery, exp.UDTF)): 906 # The following args are not actually in the inner scope, so we should visit them 907 for key in ("joins", "laterals", "pivots"): 908 for arg in node.args.get(key) or []: 909 yield from walk_in_scope(arg, bfs=bfs) 910 911 912def find_all_in_scope(expression, expression_types, bfs=True): 913 """ 914 Returns a generator object which visits all nodes in this scope and only yields those that 915 match at least one of the specified expression types. 916 917 This does NOT traverse into subscopes. 918 919 Args: 920 expression (exp.Expression): 921 expression_types (tuple[type]|type): the expression type(s) to match. 922 bfs (bool): True to use breadth-first search, False to use depth-first. 923 924 Yields: 925 exp.Expression: nodes 926 """ 927 for expression in walk_in_scope(expression, bfs=bfs): 928 if isinstance(expression, tuple(ensure_collection(expression_types))): 929 yield expression 930 931 932def find_in_scope(expression, expression_types, bfs=True): 933 """ 934 Returns the first node in this scope which matches at least one of the specified types. 935 936 This does NOT traverse into subscopes. 937 938 Args: 939 expression (exp.Expression): 940 expression_types (tuple[type]|type): the expression type(s) to match. 941 bfs (bool): True to use breadth-first search, False to use depth-first. 
942 943 Returns: 944 exp.Expression: the node which matches the criteria or None if no node matching 945 the criteria was found. 946 """ 947 return next(find_all_in_scope(expression, expression_types, bfs=bfs), None) 948 949 950def _get_source_alias(expression): 951 alias_arg = expression.args.get("alias") 952 alias_name = expression.alias 953 954 if not alias_name and isinstance(alias_arg, exp.TableAlias) and len(alias_arg.columns) == 1: 955 alias_name = alias_arg.columns[0].name 956 957 return alias_name
class ScopeType(Enum):
    """Kind of a `Scope`, relative to the scope it was branched from."""

    # Default `scope_type` of a Scope that is constructed directly.
    ROOT = auto()
    # Used for scopes branched by `_traverse_subqueries` and `_traverse_udtfs`.
    SUBQUERY = auto()
    # Used by `_traverse_tables` for subqueries that `_is_derived_table` accepts.
    DERIVED_TABLE = auto()
    # Used by `_traverse_ctes`.
    CTE = auto()
    # Used by `_traverse_union` for the operands of a set operation.
    UNION = auto()
    # Used by `_traverse_tables` for `exp.UDTF` expressions.
    UDTF = auto()
An enumeration.
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.SetOperation): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources (dict[str, Scope]): Sources from CTEs
        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
        can_be_correlated (bool|None): Truthy when this scope, or a scope it was branched
            from, has the SUBQUERY or UDTF scope type (see `branch`).
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_columns=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
        cte_sources=None,
        can_be_correlated=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # Lateral and CTE sources are visible as regular sources too.
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.can_be_correlated = can_be_correlated
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._table_columns = None
        self._stars = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._local_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None
        self._semi_anti_join_tables = None
        self._column_index = None

    def branch(
        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
    ):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            can_be_correlated=self.can_be_correlated
            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
            **kwargs,
        )

    def _collect(self):
        """Walk this scope's expression once and bucket its nodes by kind."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._table_columns = []
        self._stars = []
        self._join_hints = []
        self._semi_anti_join_tables = set()
        self._column_index = set()

        for node in self.walk(bfs=False):
            if node is self.expression:
                continue

            if isinstance(node, exp.Dot) and node.is_star:
                self._stars.append(node)
            elif isinstance(node, exp.Column) and not isinstance(node, exp.Pseudocolumn):
                self._column_index.add(id(node))

                if isinstance(node.this, exp.Star):
                    self._stars.append(node)
                else:
                    self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                parent = node.parent
                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
                    self._semi_anti_join_tables.add(node.alias_or_name)

                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif _is_derived_table(node) and _is_from_or_join(node):
                self._derived_tables.append(node)
            elif isinstance(node, exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node):
                self._subqueries.append(node)
            elif isinstance(node, exp.TableColumn):
                self._table_columns.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` unless it already ran since the last `clear_cache`."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True, prune=None):
        """
        Walk the nodes of this scope's expression without entering child scopes.

        Args:
            bfs (bool): True to use breadth-first search, False to use depth-first.
            prune: optional callable forwarded to `walk_in_scope`.

        Returns:
            The generator produced by `walk_in_scope`.
        """
        # `prune` used to be accepted but silently replaced by None.
        return walk_in_scope(self.expression, bfs=bfs, prune=prune)

    def find(self, *expression_types, bfs=True):
        """Return the first node in this scope matching any of `expression_types`, or None."""
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """Return a generator over the nodes in this scope matching any of `expression_types`."""
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.SetOperation]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def stars(self) -> t.List[exp.Column | exp.Dot]:
        """
        List of star expressions (columns or dots) in this scope.
        """
        self._ensure_collected()
        return self._stars

    @property
    def column_index(self) -> t.Set[int]:
        """
        Set of column object IDs that belong to this scope's expression.
        """
        self._ensure_collected()
        return self._column_index

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(
                    self.subquery_scopes,
                    self.udtf_scopes,
                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
                )
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select,
                    exp.Qualify,
                    exp.Order,
                    exp.Having,
                    exp.Hint,
                    exp.Table,
                    exp.Star,
                    exp.Distinct,
                )
                # Qualified columns are always kept. Unqualified ones are kept only
                # for the ancestor kinds listed below; e.g. an unqualified column under
                # an ORDER BY / DISTINCT that sits directly on a SELECT is dropped when
                # its name matches one of the named selects.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, (exp.Order, exp.Distinct))
                        and (
                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
                            or not isinstance(ancestor.parent, exp.Select)
                            or column.name not in named_selects
                        )
                    )
                    or (isinstance(ancestor, exp.Star) and column.arg_key != "except_")
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def table_columns(self):
        """
        List of `exp.TableColumn` nodes found in this scope.

        Returns:
            list[exp.TableColumn]: table column nodes
        """
        if self._table_columns is None:
            self._ensure_collected()

        return self._table_columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if the same name is referenced more than once.
        """
        if self._selected_sources is None:
            result = {}
            semi_anti_tables = self.semi_or_anti_join_tables

            for name, node in self.references:
                if name in semi_anti_tables:
                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
                    # selected source
                    continue

                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        List of `(name, node)` pairs for every table, derived table and UDTF in this scope.

        Tables are listed first, followed by derived tables and UDTFs. Derived tables and
        UDTFs are unnested unless they carry pivots.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        _get_source_alias(expression),
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.SetOperation):
                # A set operation has no columns of its own: combine both branches.
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                self._external_columns = [
                    c
                    for c in self.columns
                    if c.table not in self.sources and c.table not in self.semi_or_anti_join_tables
                ]

        return self._external_columns

    @property
    def local_columns(self):
        """
        Columns in this scope that are not external.

        Returns:
            list[exp.Column]: Column instances that reference sources in the current scope.
        """
        if self._local_columns is None:
            external_columns = set(self.external_columns)
            self._local_columns = [c for c in self.columns if c not in external_columns]

        return self._local_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand so the result does not depend on whether another
        # property happened to be accessed first.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        List of pivots attached to the nodes in `references`.

        Returns:
            list: the entries of each referenced node's "pivots" arg, flattened
        """
        # Compare against None so that an empty result is cached as well.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    @property
    def semi_or_anti_join_tables(self):
        """
        Names (alias or name) of tables whose parent is a SEMI or ANTI join.

        Returns:
            set[str]: table names
        """
        self._ensure_collected()
        return self._semi_anti_join_tables

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.can_be_correlated and self.external_columns)

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        old_name = old_name or ""
        if old_name in self.sources:
            self.sources[new_name] = self.sources.pop(old_name)

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        stack = [self]
        result = []
        while stack:
            scope = stack.pop()
            result.append(scope)
            stack.extend(
                itertools.chain(
                    scope.cte_scopes,
                    scope.union_scopes,
                    scope.table_scopes,
                    scope.subquery_scopes,
                )
            )

        # Scopes were recorded parent-first; reversing yields children before parents.
        yield from reversed(result)

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(lambda: 0)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

            for name in scope.semi_or_anti_join_tables:
                # semi/anti join sources are not actually selected but we still need to
                # increment their ref count to avoid them being optimized away
                if name in scope.sources:
                    scope_ref_count[id(scope.sources[name])] += 1

        return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.SetOperation): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example: `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;` The LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_columns`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_columns=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
    can_be_correlated=None,
):
    """
    Create a scope rooted at `expression`.

    Args:
        expression: Root expression of this scope.
        sources: Mapping of source name to source; a falsy value becomes a new dict.
        outer_columns: Column list stored on the scope; a falsy value becomes a new list.
        parent: Parent scope, if any.
        scope_type: Type of this scope, `ScopeType.ROOT` by default.
        lateral_sources: Sources coming from laterals; merged into `sources`.
        cte_sources: Sources coming from CTEs; merged into `sources`.
        can_be_correlated: Flag stored as-is on the scope.
    """
    # What this scope is and where it sits in the tree.
    self.expression = expression
    self.parent = parent
    self.scope_type = scope_type
    self.can_be_correlated = can_be_correlated
    self.outer_columns = outer_columns or []

    # Source mappings. Laterals are merged first and CTEs second, so on a
    # name clash the CTE entry is the one left in `self.sources`.
    self.sources = sources or {}
    self.lateral_sources = lateral_sources or {}
    self.cte_sources = cte_sources or {}
    for extra_sources in (self.lateral_sources, self.cte_sources):
        self.sources.update(extra_sources)

    # Child scope lists start empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    # Initialise every lazily computed attribute.
    self.clear_cache()
def clear_cache(self):
    """Reset all lazily computed state so that it is rebuilt on next access."""
    # Flag read by the collection step; False means nothing has been collected.
    self._collected = False

    # Raw collections of nodes.
    self._raw_columns = None
    self._table_columns = None
    self._stars = None
    self._derived_tables = None
    self._udtfs = None
    self._tables = None
    self._ctes = None
    self._subqueries = None

    # Values derived from the collections above.
    self._selected_sources = None
    self._columns = None
    self._external_columns = None
    self._local_columns = None
    self._join_hints = None
    self._pivots = None
    self._references = None
    self._semi_anti_join_tables = None
    self._column_index = None
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Create a new inner scope whose parent is this scope.

    Args:
        expression: Expression for the new scope; it is unnested first.
        scope_type: Type of the new scope.
        sources: Sources for the new scope; copied when non-empty.
        cte_sources: Extra CTE sources, layered on top of this scope's CTE sources.
        lateral_sources: Lateral sources for the new scope; copied when non-empty.
        **kwargs: Forwarded unchanged to `Scope`.

    Returns:
        Scope: the new child scope.
    """
    inner_expression = expression.unnest()
    child_sources = sources.copy() if sources else None
    # Entries from `cte_sources` override same-named entries inherited from this scope.
    merged_cte_sources = {**self.cte_sources, **(cte_sources or {})}
    child_lateral_sources = lateral_sources.copy() if lateral_sources else None
    correlatable = self.can_be_correlated or scope_type in (
        ScopeType.SUBQUERY,
        ScopeType.UDTF,
    )

    return Scope(
        expression=inner_expression,
        sources=child_sources,
        parent=self,
        scope_type=scope_type,
        cte_sources=merged_cte_sources,
        lateral_sources=child_lateral_sources,
        can_be_correlated=correlatable,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` and invalidate this scope's cached state.

    Use this rather than calling `exp.Expression.replace` directly so the `Scope`
    does not keep serving values computed before the swap.

    Args:
        old (exp.Expression): node to be replaced
        new (exp.Expression): node to put in its place
    """
    old.replace(new)
    # Cached values were computed before the swap, so drop them all.
    self.clear_cache()
Replace old with new.
This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Tables collected for this scope.

    Returns:
        list[exp.Table]: tables
    """
    # Collection is lazy; make sure it has run before reading the result.
    self._ensure_collected()
    collected = self._tables
    return collected
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    CTEs collected for this scope.

    Returns:
        list[exp.CTE]: ctes
    """
    # Collection is lazy; make sure it has run before reading the result.
    self._ensure_collected()
    collected = self._ctes
    return collected
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables collected for this scope.

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    Returns:
        list[exp.Subquery]: derived tables
    """
    # Collection is lazy; make sure it has run before reading the result.
    self._ensure_collected()
    collected = self._derived_tables
    return collected
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Functions" collected for this scope.

    Returns:
        list[exp.UDTF]: UDTFs
    """
    # Collection is lazy; make sure it has run before reading the result.
    self._ensure_collected()
    collected = self._udtfs
    return collected
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries collected for this scope.

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    Returns:
        list[exp.Select | exp.SetOperation]: subqueries
    """
    # Collection is lazy; make sure it has run before reading the result.
    self._ensure_collected()
    collected = self._subqueries
    return collected
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.SetOperation]: subqueries
@property
def stars(self) -> t.List[exp.Column | exp.Dot]:
    """
    Star expressions (columns or dots) collected for this scope.
    """
    # Collection is lazy; make sure it has run before reading the result.
    self._ensure_collected()
    collected = self._stars
    return collected
List of star expressions (columns or dots) in this scope.
@property
def column_index(self) -> t.Set[int]:
    """
    Object IDs of the columns that belong to this scope's expression, as a set.
    """
    # Collection is lazy; make sure it has run before reading the result.
    self._ensure_collected()
    index = self._column_index
    return index
Set of column object IDs that belong to this scope's expression.
@property
def columns(self):
    """
    List of columns in this scope.

    The list is computed on first access and cached in `self._columns` until the
    cache is cleared. Candidates are this scope's raw columns followed by the
    external columns of its subquery scopes, UDTF scopes, and those derived table
    scopes whose `can_be_correlated` is truthy. Each candidate is then kept or
    dropped depending on its nearest enclosing clause (see the inline comments).

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is None:
        self._ensure_collected()
        columns = self._raw_columns

        # Columns that child scopes report as external, i.e. not matching any
        # of the child's own sources.
        external_columns = [
            column
            for scope in itertools.chain(
                self.subquery_scopes,
                self.udtf_scopes,
                (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
            )
            for column in scope.external_columns
        ]

        named_selects = set(self.expression.named_selects)

        self._columns = []
        for column in columns + external_columns:
            # Nearest ancestor among the node types that influence whether the
            # column is kept.
            ancestor = column.find_ancestor(
                exp.Select,
                exp.Qualify,
                exp.Order,
                exp.Having,
                exp.Hint,
                exp.Table,
                exp.Star,
                exp.Distinct,
            )
            # Keep the column when any of the following holds:
            #   - it has none of the ancestors listed above
            #   - it is qualified with a table name
            #   - the nearest such ancestor is a Select
            #   - the nearest such ancestor is a Table whose `this` is not a Func
            #   - the ancestor is an Order/Distinct that belongs to a
            #     Window/WithinGroup, is not directly under a Select, or the
            #     column's name is not one of the select's named outputs
            #     (NOTE(review): presumably a matching name refers to the select
            #     alias rather than a source column; confirm)
            #   - the ancestor is a Star and the column is not its `except_` arg
            # Unqualified columns whose nearest such ancestor is Qualify, Having
            # or Hint match none of these and are dropped.
            if (
                not ancestor
                or column.table
                or isinstance(ancestor, exp.Select)
                or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                or (
                    isinstance(ancestor, (exp.Order, exp.Distinct))
                    and (
                        isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
                        or not isinstance(ancestor.parent, exp.Select)
                        or column.name not in named_selects
                    )
                )
                or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
            ):
                self._columns.append(column)

    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Mapping of the sources this scope actually selects from, with their nodes.

    Every table in a schema can be selected at any point, but it only becomes
    a selected source once it appears in a FROM or JOIN clause.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

    Raises:
        OptimizeError: if the same name is referenced more than once.
    """
    if self._selected_sources is not None:
        return self._selected_sources

    selected = {}
    for name, node in self.references:
        # The RHS table of SEMI/ANTI joins shouldn't be collected as a
        # selected source
        if name in self._semi_anti_join_tables:
            continue

        if name in selected:
            raise OptimizeError(f"Alias already used: {name}")

        # References with no matching source are left out.
        if name in self.sources:
            selected[name] = (node, self.sources[name])

    self._selected_sources = selected
    return selected
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    (name, node) pairs for every table, derived table and UDTF in this scope.

    Tables come first, named by `alias_or_name`; derived tables and UDTFs
    follow, named by `_get_source_alias`. The list is cached after the first
    access.
    """
    if self._references is not None:
        return self._references

    # Publish the list before filling it, as the cache attribute and the
    # local name refer to the same object.
    collected = self._references = []

    for table in self.tables:
        collected.append((table.alias_or_name, table))

    for source in itertools.chain(self.derived_tables, self.udtfs):
        alias = _get_source_alias(source)
        # A node that has pivots is kept as-is; otherwise it is unnested.
        node = source if source.args.get("pivots") else source.unnest()
        collected.append((alias, node))

    return collected
@property
def external_columns(self):
    """
    Columns that look like references to sources in outer scopes.

    For a set operation this is the left child scope's external columns
    followed by the right child's. Otherwise it is every column whose table
    name is neither one of this scope's sources nor one of its semi/anti
    join tables. The result is cached.

    Returns:
        list[exp.Column]: Column instances that don't reference sources in the current scope.
    """
    if self._external_columns is not None:
        return self._external_columns

    if isinstance(self.expression, exp.SetOperation):
        left_scope, right_scope = self.union_scopes
        found = left_scope.external_columns + right_scope.external_columns
    else:
        found = [
            column
            for column in self.columns
            if not (
                column.table in self.sources
                or column.table in self.semi_or_anti_join_tables
            )
        ]

    self._external_columns = found
    return found
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def local_columns(self):
    """
    Columns of this scope that are not in `external_columns`.

    The result is cached after the first access.

    Returns:
        list[exp.Column]: Column instances that reference sources in the current scope.
    """
    if self._local_columns is not None:
        return self._local_columns

    # A set makes the membership test below cheap.
    outside = set(self.external_columns)
    local = []
    for column in self.columns:
        if column not in outside:
            local.append(column)

    self._local_columns = local
    return local
Columns in this scope that are not external.
Returns:
list[exp.Column]: Column instances that reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns of this scope that carry no table qualifier.

    Returns:
        list[exp.Column]: Unqualified columns
    """
    unqualified = []
    for column in self.columns:
        if column.table:
            continue
        unqualified.append(column)
    return unqualified
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Join hints in this scope that reference tables.

    Reads the stored value directly; when none has been stored yet, a new
    empty list is returned.

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    return [] if self._join_hints is None else self._join_hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Collect the columns of this scope whose table is `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances that reference `source_name`
    """
    matching = []
    for column in self.columns:
        if column.table == source_name:
            matching.append(column)
    return matching
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self):
    """Whether this scope's type is `ScopeType.SUBQUERY`"""
    expected = ScopeType.SUBQUERY
    return self.scope_type == expected
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """Whether this scope's type is `ScopeType.DERIVED_TABLE`"""
    expected = ScopeType.DERIVED_TABLE
    return self.scope_type == expected
Determine if this scope is a derived table
@property
def is_union(self):
    """Whether this scope's type is `ScopeType.UNION`"""
    expected = ScopeType.UNION
    return self.scope_type == expected
Determine if this scope is a union
@property
def is_cte(self):
    """Whether this scope's type is `ScopeType.CTE` (a common table expression)"""
    expected = ScopeType.CTE
    return self.scope_type == expected
Determine if this scope is a common table expression
@property
def is_root(self):
    """Whether this scope's type is `ScopeType.ROOT`"""
    expected = ScopeType.ROOT
    return self.scope_type == expected
Determine if this is the root scope
@property
def is_udtf(self):
    """Whether this scope's type is `ScopeType.UDTF` (User Defined Table Function)"""
    expected = ScopeType.UDTF
    return self.scope_type == expected
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Move the source stored under `old_name` to `new_name`.

    A falsy `old_name` is treated as the empty string. Nothing happens when
    no source is stored under the old name.
    """
    current_name = old_name or ""
    if current_name not in self.sources:
        return

    # NOTE(review): this does not call `clear_cache()`, so values cached from
    # the sources before the rename are kept - confirm that is intended.
    source = self.sources.pop(current_name)
    self.sources[new_name] = source
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name`, replacing any existing entry, and clear the cache"""
    self.sources.update({name: source})
    # Cached values may have been derived from the old set of sources.
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Drop the source stored under `name`, if there is one, and clear the cache"""
    if name in self.sources:
        del self.sources[name]
    # The cache is cleared even when there was nothing to remove.
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree starting at this scope.

    Scopes are first gathered with an explicit stack and then yielded in the
    reverse of the order in which they were visited, so this scope is always
    yielded last.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    pending = [self]
    visited = []

    while pending:
        current = pending.pop()
        visited.append(current)
        # Children are pushed in this fixed order: CTEs, unions, tables, subqueries.
        child_groups = (
            current.cte_scopes,
            current.union_scopes,
            current.table_scopes,
            current.subquery_scopes,
        )
        for children in child_groups:
            pending.extend(children)

    for scope in reversed(visited):
        yield scope
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Tally how often each source in this scope tree is referenced.

    Sources are keyed by `id(...)`, so the mapping is only meaningful while
    the counted objects are alive.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(int)

    for scope in self.traverse():
        for _node, source in scope.selected_sources.values():
            counts[id(source)] += 1

        for name in scope._semi_anti_join_tables:
            # semi/anti join sources are not actually selected but we still need to
            # increment their ref count to avoid them being optimized away
            if name not in scope.sources:
                continue
            counts[id(scope.sources[name])] += 1

    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Collect the "scopes" of an expression.

    A "Scope" captures the context a Select statement is evaluated in, which
    carries more information than the expression tree alone. An optimization
    may, for instance, need the names of the sources inside a subquery.

    The scopes are returned as a list and not as a generator, because a
    partially consumed generator could expose scopes with incomplete properties.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty when `expression` is not
        an instance of one of the `TRAVERSABLES` types.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") >>> scopes = traverse_scope(expression) >>> scopes[0].expression.sql(), list(scopes[0].sources) ('SELECT a FROM x', ['x']) >>> scopes[1].expression.sql(), list(scopes[1].sources) ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build the scope tree for an expression and return its root.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, taken as the last entry of `traverse_scope(expression)`.
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    A node that starts a child scope is itself yielded, but its subtree is not
    walked. For `exp.Subquery` and `exp.UDTF` boundary nodes, the `joins`,
    `laterals` and `pivots` args are still walked.

    Args:
        expression (exp.Expression): root node to start walking from.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that is given a node and returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: the visited nodes
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    # The lambda reads `crossed_scope_boundary` each time it is called, so a
    # value set after a `yield` below is what the walker's next prune check sees.
    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        # The root is never treated as a scope boundary.
        if node is expression:
            continue

        # Node types and positions that mark the start of a child scope.
        if (
            isinstance(node, exp.CTE)
            or (isinstance(node.parent, (exp.From, exp.Join)) and _is_derived_table(node))
            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        # NOTE(review): `prune` is not forwarded to this nested walk.
                        yield from walk_in_scope(arg, bfs=bfs)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression):
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node) -> bool): callable that is given a node and returns True if the generator should stop traversing this branch of the tree.
Yields:
exp.Expression: the visited nodes
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root node of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # The type tuple does not change between nodes, so build it once instead
    # of once per visited node.
    types = tuple(ensure_collection(expression_types))

    # A separate loop name keeps the `expression` parameter from being shadowed.
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Find the first node in this scope that matches at least one of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root node of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    matches = find_all_in_scope(expression, expression_types, bfs=bfs)
    # Only the first match is consumed; the rest of the generator is never run.
    return next(matches, None)
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.