sqlglot.optimizer.scope
from __future__ import annotations

import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import ensure_collection, find_new_name, seq_get

logger = logging.getLogger("sqlglot")

# Query, DDL and DML expression classes, grouped as a tuple for isinstance checks.
TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)


class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.SetOperation): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources (dict[str, Scope]): Sources from CTEs
        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
        can_be_correlated (bool|None): Whether this scope may reference sources of an
            enclosing scope. `branch` sets it for SUBQUERY and UDTF scopes and propagates
            it to every scope branched from such a scope.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_columns=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
        cte_sources=None,
        can_be_correlated=None,
    ):
        self.expression = expression
        # NOTE(review): a non-empty `sources` dict is used as-is (not copied) and is
        # updated in place below; `branch` passes a copy, direct callers should too.
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # Lateral and CTE sources are also resolvable through `sources`.
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.can_be_correlated = can_be_correlated
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is rebuilt on next access."""
        self._collected = False
        self._raw_columns = None
        self._table_columns = None
        self._stars = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._local_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None
        self._semi_anti_join_tables = None
        self._column_index = None

    def branch(
        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
    ):
        """
        Branch from the current scope to a new, inner scope.

        The child scope gets copies of `sources` / `lateral_sources`, inherits this
        scope's CTE sources (entries in `cte_sources` take precedence), and is marked
        as possibly correlated if this scope is, or if it is a SUBQUERY or UDTF scope.
        Extra keyword arguments are forwarded to the `Scope` constructor.
        """
        return Scope(
            expression=expression.unnest(),
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            can_be_correlated=self.can_be_correlated
            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
            **kwargs,
        )

    def _collect(self):
        """
        Walk this scope's expression once (depth-first, via `walk`) and bucket the nodes
        into tables, CTEs, subqueries, derived tables, UDTFs, columns, stars and join hints.
        """
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._table_columns = []
        self._stars = []
        self._join_hints = []
        self._semi_anti_join_tables = set()
        self._column_index = set()

        for node in self.walk(bfs=False):
            if node is self.expression:
                continue

            node_type = type(node)

            if node_type is exp.Dot and node.is_star:
                self._stars.append(node)
            elif node_type is exp.Column:
                self._column_index.add(id(node))

                if type(node.this) is exp.Star:
                    self._stars.append(node)
                else:
                    self._raw_columns.append(node)
            elif node_type is exp.Table and type(node.parent) is not exp.JoinHint:
                parent = node.parent
                # Remember the RHS of SEMI/ANTI joins so they can be excluded later.
                if type(parent) is exp.Join and parent.is_semi_or_anti_join:
                    self._semi_anti_join_tables.add(node.alias_or_name)

                self._tables.append(node)
            elif node_type is exp.JoinHint:
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif node_type is exp.CTE:
                self._ctes.append(node)
            elif _is_derived_table(node) and _is_from_or_join(node):
                self._derived_tables.append(node)
            elif isinstance(node, exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node):
                self._subqueries.append(node)
            elif node_type is exp.TableColumn:
                self._table_columns.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` if it has not run since the last `clear_cache`."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True, prune=None):
        """
        Walk the nodes of this scope's expression with `walk_in_scope`.

        Args:
            bfs (bool): traversal order flag, forwarded to `walk_in_scope`.
            prune (callable|None): optional pruning callback, forwarded to `walk_in_scope`.
        """
        # `prune` used to be dropped (hard-coded to None), so callers' pruning was ignored.
        return walk_in_scope(self.expression, bfs=bfs, prune=prune)

    def find(self, *expression_types, bfs=True):
        """Return the first node in this scope matching `expression_types` (via `find_in_scope`)."""
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """Return all nodes in this scope matching `expression_types` (via `find_all_in_scope`)."""
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.SetOperation]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def stars(self) -> t.List[exp.Column | exp.Dot]:
        """
        List of star expressions (columns or dots) in this scope.
        """
        self._ensure_collected()
        return self._stars

    @property
    def column_index(self) -> t.Set[int]:
        """
        Set of column object IDs that belong to this scope's expression.
        """
        self._ensure_collected()
        return self._column_index

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            # Columns of child scopes that do not resolve inside the child may refer to
            # this scope. Only subquery, UDTF and possibly-correlated derived-table
            # child scopes are consulted.
            external_columns = [
                column
                for scope in itertools.chain(
                    self.subquery_scopes,
                    self.udtf_scopes,
                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
                )
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select,
                    exp.Qualify,
                    exp.Order,
                    exp.Having,
                    exp.Hint,
                    exp.Table,
                    exp.Star,
                    exp.Distinct,
                )
                # Qualified columns are always kept. Unqualified ones are dropped when the
                # nearest ancestor above is QUALIFY/HAVING/a hint, a table function, a
                # star's EXCEPT list, or a SELECT-level ORDER BY / DISTINCT whose name
                # matches a projection. Presumably these may name something other than a
                # source column (e.g. a projection alias) — confirm against resolver logic.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, (exp.Order, exp.Distinct))
                        and (
                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
                            or not isinstance(ancestor.parent, exp.Select)
                            or column.name not in named_selects
                        )
                    )
                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
                ):
                    self._columns.append(column)

        return self._columns
@property 333 def table_columns(self): 334 if self._table_columns is None: 335 self._ensure_collected() 336 337 return self._table_columns 338 339 @property 340 def selected_sources(self): 341 """ 342 Mapping of nodes and sources that are actually selected from in this scope. 343 344 That is, all tables in a schema are selectable at any point. But a 345 table only becomes a selected source if it's included in a FROM or JOIN clause. 346 347 Returns: 348 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 349 """ 350 if self._selected_sources is None: 351 result = {} 352 353 for name, node in self.references: 354 if name in self._semi_anti_join_tables: 355 # The RHS table of SEMI/ANTI joins shouldn't be collected as a 356 # selected source 357 continue 358 359 if name in result: 360 raise OptimizeError(f"Alias already used: {name}") 361 if name in self.sources: 362 result[name] = (node, self.sources[name]) 363 364 self._selected_sources = result 365 return self._selected_sources 366 367 @property 368 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 369 if self._references is None: 370 self._references = [] 371 372 for table in self.tables: 373 self._references.append((table.alias_or_name, table)) 374 for expression in itertools.chain(self.derived_tables, self.udtfs): 375 self._references.append( 376 ( 377 _get_source_alias(expression), 378 expression if expression.args.get("pivots") else expression.unnest(), 379 ) 380 ) 381 382 return self._references 383 384 @property 385 def external_columns(self): 386 """ 387 Columns that appear to reference sources in outer scopes. 388 389 Returns: 390 list[exp.Column]: Column instances that don't reference sources in the current scope. 
391 """ 392 if self._external_columns is None: 393 if isinstance(self.expression, exp.SetOperation): 394 left, right = self.union_scopes 395 self._external_columns = left.external_columns + right.external_columns 396 else: 397 self._external_columns = [ 398 c 399 for c in self.columns 400 if c.table not in self.sources and c.table not in self.semi_or_anti_join_tables 401 ] 402 403 return self._external_columns 404 405 @property 406 def local_columns(self): 407 """ 408 Columns in this scope that are not external. 409 410 Returns: 411 list[exp.Column]: Column instances that reference sources in the current scope. 412 """ 413 if self._local_columns is None: 414 external_columns = set(self.external_columns) 415 self._local_columns = [c for c in self.columns if c not in external_columns] 416 417 return self._local_columns 418 419 @property 420 def unqualified_columns(self): 421 """ 422 Unqualified columns in the current scope. 423 424 Returns: 425 list[exp.Column]: Unqualified columns 426 """ 427 return [c for c in self.columns if not c.table] 428 429 @property 430 def join_hints(self): 431 """ 432 Hints that exist in the scope that reference tables 433 434 Returns: 435 list[exp.JoinHint]: Join hints that are referenced within the scope 436 """ 437 if self._join_hints is None: 438 return [] 439 return self._join_hints 440 441 @property 442 def pivots(self): 443 if not self._pivots: 444 self._pivots = [ 445 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 446 ] 447 448 return self._pivots 449 450 @property 451 def semi_or_anti_join_tables(self): 452 return self._semi_anti_join_tables or set() 453 454 def source_columns(self, source_name): 455 """ 456 Get all columns in the current scope for a particular source. 
457 458 Args: 459 source_name (str): Name of the source 460 Returns: 461 list[exp.Column]: Column instances that reference `source_name` 462 """ 463 return [column for column in self.columns if column.table == source_name] 464 465 @property 466 def is_subquery(self): 467 """Determine if this scope is a subquery""" 468 return self.scope_type == ScopeType.SUBQUERY 469 470 @property 471 def is_derived_table(self): 472 """Determine if this scope is a derived table""" 473 return self.scope_type == ScopeType.DERIVED_TABLE 474 475 @property 476 def is_union(self): 477 """Determine if this scope is a union""" 478 return self.scope_type == ScopeType.UNION 479 480 @property 481 def is_cte(self): 482 """Determine if this scope is a common table expression""" 483 return self.scope_type == ScopeType.CTE 484 485 @property 486 def is_root(self): 487 """Determine if this is the root scope""" 488 return self.scope_type == ScopeType.ROOT 489 490 @property 491 def is_udtf(self): 492 """Determine if this scope is a UDTF (User Defined Table Function)""" 493 return self.scope_type == ScopeType.UDTF 494 495 @property 496 def is_correlated_subquery(self): 497 """Determine if this scope is a correlated subquery""" 498 return bool(self.can_be_correlated and self.external_columns) 499 500 def rename_source(self, old_name, new_name): 501 """Rename a source in this scope""" 502 old_name = old_name or "" 503 if old_name in self.sources: 504 self.sources[new_name] = self.sources.pop(old_name) 505 506 def add_source(self, name, source): 507 """Add a source to this scope""" 508 self.sources[name] = source 509 self.clear_cache() 510 511 def remove_source(self, name): 512 """Remove a source from this scope""" 513 self.sources.pop(name, None) 514 self.clear_cache() 515 516 def __repr__(self): 517 return f"Scope<{self.expression.sql()}>" 518 519 def traverse(self): 520 """ 521 Traverse the scope tree from this node. 
522 523 Yields: 524 Scope: scope instances in depth-first-search post-order 525 """ 526 stack = [self] 527 result = [] 528 while stack: 529 scope = stack.pop() 530 result.append(scope) 531 stack.extend( 532 itertools.chain( 533 scope.cte_scopes, 534 scope.union_scopes, 535 scope.table_scopes, 536 scope.subquery_scopes, 537 ) 538 ) 539 540 yield from reversed(result) 541 542 def ref_count(self): 543 """ 544 Count the number of times each scope in this tree is referenced. 545 546 Returns: 547 dict[int, int]: Mapping of Scope instance ID to reference count 548 """ 549 scope_ref_count = defaultdict(lambda: 0) 550 551 for scope in self.traverse(): 552 for _, source in scope.selected_sources.values(): 553 scope_ref_count[id(source)] += 1 554 555 for name in scope._semi_anti_join_tables: 556 # semi/anti join sources are not actually selected but we still need to 557 # increment their ref count to avoid them being optimized away 558 if name in scope.sources: 559 scope_ref_count[id(scope.sources[name])] += 1 560 561 return scope_ref_count 562 563 564def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 565 """ 566 Traverse an expression by its "scopes". 567 568 "Scope" represents the current context of a Select statement. 569 570 This is helpful for optimizing queries, where we need more information than 571 the expression tree itself. For example, we might care about the source 572 names within a subquery. Returns a list because a generator could result in 573 incomplete properties which is confusing. 
574 575 Examples: 576 >>> import sqlglot 577 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 578 >>> scopes = traverse_scope(expression) 579 >>> scopes[0].expression.sql(), list(scopes[0].sources) 580 ('SELECT a FROM x', ['x']) 581 >>> scopes[1].expression.sql(), list(scopes[1].sources) 582 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 583 584 Args: 585 expression: Expression to traverse 586 587 Returns: 588 A list of the created scope instances 589 """ 590 if isinstance(expression, TRAVERSABLES): 591 return list(_traverse_scope(Scope(expression))) 592 return [] 593 594 595def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 596 """ 597 Build a scope tree. 598 599 Args: 600 expression: Expression to build the scope tree for. 601 602 Returns: 603 The root scope 604 """ 605 return seq_get(traverse_scope(expression), -1) 606 607 608def _traverse_scope(scope): 609 expression = scope.expression 610 611 if isinstance(expression, exp.Select): 612 yield from _traverse_select(scope) 613 elif isinstance(expression, exp.SetOperation): 614 yield from _traverse_ctes(scope) 615 yield from _traverse_union(scope) 616 return 617 elif isinstance(expression, exp.Subquery): 618 if scope.is_root: 619 yield from _traverse_select(scope) 620 else: 621 yield from _traverse_subqueries(scope) 622 elif isinstance(expression, exp.Table): 623 yield from _traverse_tables(scope) 624 elif isinstance(expression, exp.UDTF): 625 yield from _traverse_udtfs(scope) 626 elif isinstance(expression, exp.DDL): 627 if isinstance(expression.expression, exp.Query): 628 yield from _traverse_ctes(scope) 629 yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources)) 630 return 631 elif isinstance(expression, exp.DML): 632 yield from _traverse_ctes(scope) 633 for query in find_all_in_scope(expression, exp.Query): 634 # This check ensures we don't yield the CTE/nested queries twice 635 if not isinstance(query.parent, (exp.CTE, exp.Subquery)): 
636 yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources)) 637 return 638 else: 639 logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression)) 640 return 641 642 yield scope 643 644 645def _traverse_select(scope): 646 yield from _traverse_ctes(scope) 647 yield from _traverse_tables(scope) 648 yield from _traverse_subqueries(scope) 649 650 651def _traverse_union(scope): 652 prev_scope = None 653 union_scope_stack = [scope] 654 expression_stack = [scope.expression.right, scope.expression.left] 655 656 while expression_stack: 657 expression = expression_stack.pop() 658 union_scope = union_scope_stack[-1] 659 660 new_scope = union_scope.branch( 661 expression, 662 outer_columns=union_scope.outer_columns, 663 scope_type=ScopeType.UNION, 664 ) 665 666 if isinstance(expression, exp.SetOperation): 667 yield from _traverse_ctes(new_scope) 668 669 union_scope_stack.append(new_scope) 670 expression_stack.extend([expression.right, expression.left]) 671 continue 672 673 for scope in _traverse_scope(new_scope): 674 yield scope 675 676 if prev_scope: 677 union_scope_stack.pop() 678 union_scope.union_scopes = [prev_scope, scope] 679 prev_scope = union_scope 680 681 yield union_scope 682 else: 683 prev_scope = scope 684 685 686def _traverse_ctes(scope): 687 sources = {} 688 689 for cte in scope.ctes: 690 cte_name = cte.alias 691 692 # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. 693 # thus the recursive scope is the first section of the union. 
694 with_ = scope.expression.args.get("with_") 695 if with_ and with_.recursive: 696 union = cte.this 697 698 if isinstance(union, exp.SetOperation): 699 sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE) 700 701 child_scope = None 702 703 for child_scope in _traverse_scope( 704 scope.branch( 705 cte.this, 706 cte_sources=sources, 707 outer_columns=cte.alias_column_names, 708 scope_type=ScopeType.CTE, 709 ) 710 ): 711 yield child_scope 712 713 # append the final child_scope yielded 714 if child_scope: 715 sources[cte_name] = child_scope 716 scope.cte_scopes.append(child_scope) 717 718 scope.sources.update(sources) 719 scope.cte_sources.update(sources) 720 721 722def _is_derived_table(expression: exp.Subquery) -> bool: 723 """ 724 We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", 725 as it doesn't introduce a new scope. If an alias is present, it shadows all names 726 under the Subquery, so that's one exception to this rule. 727 """ 728 return type(expression) is exp.Subquery and bool( 729 expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES) 730 ) 731 732 733def _is_from_or_join(expression: exp.Expression) -> bool: 734 """ 735 Determine if `expression` is the FROM or JOIN clause of a SELECT statement. 
736 """ 737 parent = expression.parent 738 739 # Subqueries can be arbitrarily nested 740 while type(parent) is exp.Subquery: 741 parent = parent.parent 742 743 return type(parent) in (exp.From, exp.Join) 744 745 746def _traverse_tables(scope): 747 sources = {} 748 749 # Traverse FROMs, JOINs, and LATERALs in the order they are defined 750 expressions = [] 751 from_ = scope.expression.args.get("from_") 752 if from_: 753 expressions.append(from_.this) 754 755 for join in scope.expression.args.get("joins") or []: 756 expressions.append(join.this) 757 758 if isinstance(scope.expression, exp.Table): 759 expressions.append(scope.expression) 760 761 expressions.extend(scope.expression.args.get("laterals") or []) 762 763 for expression in expressions: 764 if isinstance(expression, exp.Final): 765 expression = expression.this 766 if isinstance(expression, exp.Table): 767 table_name = expression.name 768 source_name = expression.alias_or_name 769 770 if table_name in scope.sources and not expression.db: 771 # This is a reference to a parent source (e.g. a CTE), not an actual table, unless 772 # it is pivoted, because then we get back a new table and hence a new source. 
773 pivots = expression.args.get("pivots") 774 if pivots: 775 sources[pivots[0].alias] = expression 776 else: 777 sources[source_name] = scope.sources[table_name] 778 elif source_name in sources: 779 sources[find_new_name(sources, table_name)] = expression 780 else: 781 sources[source_name] = expression 782 783 # Make sure to not include the joins twice 784 if expression is not scope.expression: 785 expressions.extend(join.this for join in expression.args.get("joins") or []) 786 787 continue 788 789 if not isinstance(expression, exp.DerivedTable): 790 continue 791 792 if isinstance(expression, exp.UDTF): 793 lateral_sources = sources 794 scope_type = ScopeType.UDTF 795 scopes = scope.udtf_scopes 796 elif _is_derived_table(expression): 797 lateral_sources = None 798 scope_type = ScopeType.DERIVED_TABLE 799 scopes = scope.derived_table_scopes 800 expressions.extend(join.this for join in expression.args.get("joins") or []) 801 else: 802 # Makes sure we check for possible sources in nested table constructs 803 expressions.append(expression.this) 804 expressions.extend(join.this for join in expression.args.get("joins") or []) 805 continue 806 807 child_scope = None 808 809 for child_scope in _traverse_scope( 810 scope.branch( 811 expression, 812 lateral_sources=lateral_sources, 813 outer_columns=expression.alias_column_names, 814 scope_type=scope_type, 815 ) 816 ): 817 yield child_scope 818 819 # Tables without aliases will be set as "" 820 # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. 821 # Until then, this means that only a single, unaliased derived table is allowed (rather, 822 # the latest one wins. 
823 sources[_get_source_alias(expression)] = child_scope 824 825 # append the final child_scope yielded 826 if child_scope: 827 scopes.append(child_scope) 828 scope.table_scopes.append(child_scope) 829 830 scope.sources.update(sources) 831 832 833def _traverse_subqueries(scope): 834 for subquery in scope.subqueries: 835 top = None 836 for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): 837 yield child_scope 838 top = child_scope 839 scope.subquery_scopes.append(top) 840 841 842def _traverse_udtfs(scope): 843 if isinstance(scope.expression, exp.Unnest): 844 expressions = scope.expression.expressions 845 elif isinstance(scope.expression, exp.Lateral): 846 expressions = [scope.expression.this] 847 else: 848 expressions = [] 849 850 sources = {} 851 for expression in expressions: 852 if isinstance(expression, exp.Subquery): 853 top = None 854 for child_scope in _traverse_scope( 855 scope.branch( 856 expression, 857 scope_type=ScopeType.SUBQUERY, 858 outer_columns=expression.alias_column_names, 859 ) 860 ): 861 yield child_scope 862 top = child_scope 863 sources[_get_source_alias(expression)] = child_scope 864 865 scope.subquery_scopes.append(top) 866 867 scope.sources.update(sources) 868 869 870def walk_in_scope(expression, bfs=True, prune=None): 871 """ 872 Returns a generator object which visits all nodes in the syntrax tree, stopping at 873 nodes that start child scopes. 874 875 Args: 876 expression (exp.Expression): 877 bfs (bool): if set to True the BFS traversal order will be applied, 878 otherwise the DFS traversal will be used instead. 879 prune ((node, parent, arg_key) -> bool): callable that returns True if 880 the generator should stop traversing this branch of the tree. 881 882 Yields: 883 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key 884 """ 885 # We'll use this variable to pass state into the dfs generator. 886 # Whenever we set it to True, we exclude a subtree from traversal. 
887 crossed_scope_boundary = False 888 889 for node in expression.walk( 890 bfs=bfs, prune=lambda n: bool(crossed_scope_boundary or (prune and prune(n))) 891 ): 892 crossed_scope_boundary = False 893 894 yield node 895 896 if node is expression: 897 continue 898 899 node_type = type(node) 900 parent_type = type(node.parent) 901 if ( 902 node_type is exp.CTE 903 or (parent_type in (exp.From, exp.Join) and _is_derived_table(node)) 904 or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query)) 905 or isinstance(node, exp.UNWRAPPED_QUERIES) 906 ): 907 crossed_scope_boundary = True 908 909 if node_type is exp.Subquery or isinstance(node, exp.UDTF): 910 # The following args are not actually in the inner scope, so we should visit them 911 for key in ("joins", "laterals", "pivots"): 912 for arg in node.args.get(key) or []: 913 yield from walk_in_scope(arg, bfs=bfs) 914 915 916def find_all_in_scope(expression, expression_types, bfs=True): 917 """ 918 Returns a generator object which visits all nodes in this scope and only yields those that 919 match at least one of the specified expression types. 920 921 This does NOT traverse into subscopes. 922 923 Args: 924 expression (exp.Expression): 925 expression_types (tuple[type]|type): the expression type(s) to match. 926 bfs (bool): True to use breadth-first search, False to use depth-first. 927 928 Yields: 929 exp.Expression: nodes 930 """ 931 for expression in walk_in_scope(expression, bfs=bfs): 932 if isinstance(expression, tuple(ensure_collection(expression_types))): 933 yield expression 934 935 936def find_in_scope(expression, expression_types, bfs=True): 937 """ 938 Returns the first node in this scope which matches at least one of the specified types. 939 940 This does NOT traverse into subscopes. 941 942 Args: 943 expression (exp.Expression): 944 expression_types (tuple[type]|type): the expression type(s) to match. 945 bfs (bool): True to use breadth-first search, False to use depth-first. 
946 947 Returns: 948 exp.Expression: the node which matches the criteria or None if no node matching 949 the criteria was found. 950 """ 951 return next(find_all_in_scope(expression, expression_types, bfs=bfs), None) 952 953 954def _get_source_alias(expression): 955 alias_arg = expression.args.get("alias") 956 alias_name = expression.alias 957 958 if not alias_name and isinstance(alias_arg, exp.TableAlias) and len(alias_arg.columns) == 1: 959 alias_name = alias_arg.columns[0].name 960 961 return alias_name
class ScopeType(Enum):
    """The kind of a scope, describing how it relates to its parent scope."""

    # The top-level scope built directly from the traversed expression.
    ROOT = auto()
    # A query nested inside an expression, e.g. `WHERE a IN (SELECT ...)`.
    SUBQUERY = auto()
    # A subquery used as a table in a FROM or JOIN clause.
    DERIVED_TABLE = auto()
    # A common table expression defined in a WITH clause.
    CTE = auto()
    # One operand of a set operation (UNION / INTERSECT / EXCEPT).
    UNION = auto()
    # A user defined tabular function, e.g. UNNEST or LATERAL.
    UDTF = auto()
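An enumeration of scope kinds (ROOT, SUBQUERY, DERIVED_TABLE, CTE, UNION, UDTF), describing how a scope relates to its parent scope.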
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.SetOperation): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals.
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources (dict[str, Scope]): Sources from CTEs
        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
        can_be_correlated (bool|None): Whether columns of this scope may refer to an enclosing
            scope. Set on SUBQUERY/UDTF branches and inherited by every scope branched from them.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_columns=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
        cte_sources=None,
        can_be_correlated=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # Lateral and CTE sources are also visible as regular sources of this scope.
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.can_be_correlated = can_be_correlated
        self.clear_cache()

    def clear_cache(self):
        """Drop every lazily computed attribute so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._table_columns = None
        self._stars = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._local_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None
        self._semi_anti_join_tables = None
        self._column_index = None

    def branch(
        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
    ):
        """
        Branch from the current scope to a new, inner scope.

        The child inherits this scope's CTE sources (extended by `cte_sources`) and becomes
        correlatable if this scope already is, or if it is a SUBQUERY or UDTF scope.
        """
        return Scope(
            expression=expression.unnest(),
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            can_be_correlated=self.can_be_correlated
            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
            **kwargs,
        )

    def _collect(self):
        """Walk this scope (depth-first, without entering child scopes) and bucket its nodes."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._table_columns = []
        self._stars = []
        self._join_hints = []
        self._semi_anti_join_tables = set()
        self._column_index = set()

        for node in self.walk(bfs=False):
            if node is self.expression:
                continue

            node_type = type(node)

            if node_type is exp.Dot and node.is_star:
                self._stars.append(node)
            elif node_type is exp.Column:
                self._column_index.add(id(node))

                if type(node.this) is exp.Star:
                    self._stars.append(node)
                else:
                    self._raw_columns.append(node)
            elif node_type is exp.Table and type(node.parent) is not exp.JoinHint:
                parent = node.parent
                # Remember the RHS of SEMI/ANTI joins: they are not selected sources.
                if type(parent) is exp.Join and parent.is_semi_or_anti_join:
                    self._semi_anti_join_tables.add(node.alias_or_name)

                self._tables.append(node)
            elif node_type is exp.JoinHint:
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif node_type is exp.CTE:
                self._ctes.append(node)
            elif _is_derived_table(node) and _is_from_or_join(node):
                self._derived_tables.append(node)
            elif isinstance(node, exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node):
                self._subqueries.append(node)
            elif node_type is exp.TableColumn:
                self._table_columns.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` once, lazily, before any collected attribute is read."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True, prune=None):
        """
        Walk the nodes of this scope without descending into child scopes.

        Args:
            bfs (bool): True for breadth-first order, False for depth-first.
            prune (callable|None): Optional predicate; a branch is not traversed further
                when it returns True for a node.
        """
        # Forward the caller's `prune` (previously it was silently replaced with None).
        return walk_in_scope(self.expression, bfs=bfs, prune=prune)

    def find(self, *expression_types, bfs=True):
        """Return the first node in this scope matching any of `expression_types`, or None."""
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """Yield every node in this scope matching any of `expression_types`."""
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.SetOperation]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def stars(self) -> t.List[exp.Column | exp.Dot]:
        """
        List of star expressions (columns or dots) in this scope.
        """
        self._ensure_collected()
        return self._stars

    @property
    def column_index(self) -> t.Set[int]:
        """
        Set of column object IDs that belong to this scope's expression.
        """
        self._ensure_collected()
        return self._column_index

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(
                    self.subquery_scopes,
                    self.udtf_scopes,
                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
                )
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select,
                    exp.Qualify,
                    exp.Order,
                    exp.Having,
                    exp.Hint,
                    exp.Table,
                    exp.Star,
                    exp.Distinct,
                )
                # Unqualified names in ORDER BY / DISTINCT / QUALIFY / HAVING etc. may refer to
                # projection aliases rather than source columns; those are excluded here.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, (exp.Order, exp.Distinct))
                        and (
                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
                            or not isinstance(ancestor.parent, exp.Select)
                            or column.name not in named_selects
                        )
                    )
                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def table_columns(self):
        """List of `exp.TableColumn` nodes in this scope."""
        if self._table_columns is None:
            self._ensure_collected()

        return self._table_columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two references share the same alias.
        """
        if self._selected_sources is None:
            result = {}
            # The property collects the scope first, so this is never None.
            semi_anti_join_tables = self.semi_or_anti_join_tables

            for name, node in self.references:
                if name in semi_anti_join_tables:
                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
                    # selected source
                    continue

                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """(name, node) pairs for every table, derived table and UDTF referenced in this scope."""
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        _get_source_alias(expression),
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.SetOperation):
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                self._external_columns = [
                    c
                    for c in self.columns
                    if c.table not in self.sources and c.table not in self.semi_or_anti_join_tables
                ]

        return self._external_columns

    @property
    def local_columns(self):
        """
        Columns in this scope that are not external.

        Returns:
            list[exp.Column]: Column instances that reference sources in the current scope.
        """
        if self._local_columns is None:
            external_columns = set(self.external_columns)
            self._local_columns = [c for c in self.columns if c not in external_columns]

        return self._local_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect first; previously an uncollected scope always reported no hints.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """List of pivots attached to any node referenced by this scope."""
        # `is None` so that an empty result is cached too, instead of recomputed each access.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    @property
    def semi_or_anti_join_tables(self):
        """Names (alias or table name) of tables on the right-hand side of SEMI/ANTI joins."""
        self._ensure_collected()
        return self._semi_anti_join_tables or set()

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.can_be_correlated and self.external_columns)

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        # NOTE(review): unlike add_source/remove_source this does not call clear_cache(),
        # so cached values such as selected_sources keep the old name — confirm intended.
        old_name = old_name or ""
        if old_name in self.sources:
            self.sources[new_name] = self.sources.pop(old_name)

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        stack = [self]
        result = []
        while stack:
            scope = stack.pop()
            result.append(scope)
            stack.extend(
                itertools.chain(
                    scope.cte_scopes,
                    scope.union_scopes,
                    scope.table_scopes,
                    scope.subquery_scopes,
                )
            )

        # Children were appended after their parents, so reversing puts children first.
        yield from reversed(result)

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

            for name in scope.semi_or_anti_join_tables:
                # semi/anti join sources are not actually selected but we still need to
                # increment their ref count to avoid them being optimized away
                if name in scope.sources:
                    scope_ref_count[id(scope.sources[name])] += 1

        return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.SetOperation): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: `SELECT * FROM x` gives `{"x": Table(this="x")}`; `SELECT * FROM x AS y` gives `{"y": Table(this="x")}`; `SELECT * FROM (SELECT ...) AS y` gives `{"y": Scope(...)}`.
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;` the LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_columns`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_columns=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
    can_be_correlated=None,
):
    """
    Create a scope rooted at `expression`.

    Args:
        expression: root expression of this scope.
        sources: initial mapping of source name to source. A non-empty mapping is
            used directly (not copied) and is updated in place below.
        outer_columns: column names the enclosing query assigns to this scope's alias.
        parent: the enclosing scope, if any.
        scope_type (ScopeType): type of this scope relative to its parent.
        lateral_sources: sources contributed by laterals; merged into `sources`.
        cte_sources: sources contributed by CTEs; merged into `sources` after laterals.
        can_be_correlated: whether columns of this scope may reference outer scopes.
    """
    self.expression = expression
    self.parent = parent
    self.scope_type = scope_type
    self.can_be_correlated = can_be_correlated
    self.outer_columns = outer_columns or []

    self.lateral_sources = lateral_sources or {}
    self.cte_sources = cte_sources or {}
    self.sources = sources or {}
    # Laterals are merged first and CTEs second, so on a name clash the CTE entry wins.
    for extra_sources in (self.lateral_sources, self.cte_sources):
        self.sources.update(extra_sources)

    # Every child-scope list starts out empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    self.clear_cache()
def clear_cache(self):
    """
    Forget everything this scope has computed lazily.

    Marks the scope as not yet collected and resets every cached attribute to None.
    Properties such as `columns` and `selected_sources` recompute their value the
    next time they are read.
    """
    # `_collected` is a flag rather than a cached value, so it is reset to False.
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_table_columns",
        "_stars",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_local_columns",
        "_join_hints",
        "_pivots",
        "_references",
        "_semi_anti_join_tables",
        "_column_index",
    ):
        setattr(self, cached_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Create a child scope nested inside this one.

    The child inherits this scope's CTE sources, overlaid with `cte_sources`. It gets
    shallow copies of `sources` and `lateral_sources`. It is allowed to be correlated
    when this scope is, or when it is a subquery or UDTF scope.

    Args:
        expression: root expression of the child; it is unwrapped with `unnest()`.
        scope_type (ScopeType): type of the child relative to this scope.
        sources: optional source mapping for the child (copied).
        cte_sources: extra CTE sources layered over the inherited ones.
        lateral_sources: optional lateral sources for the child (copied).
        **kwargs: forwarded unchanged to the `Scope` constructor.

    Returns:
        Scope: the new child scope.
    """
    inner_expression = expression.unnest()
    child_sources = sources.copy() if sources else None

    merged_ctes = dict(self.cte_sources)
    merged_ctes.update(cte_sources or {})

    child_laterals = lateral_sources.copy() if lateral_sources else None
    correlatable = self.can_be_correlated or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF)

    return Scope(
        expression=inner_expression,
        sources=child_sources,
        parent=self,
        scope_type=scope_type,
        cte_sources=merged_ctes,
        lateral_sources=child_laterals,
        can_be_correlated=correlatable,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` in the tree and invalidate this scope's caches.

    Prefer this over calling `exp.Expression.replace` directly, so that the
    `Scope` does not keep serving data computed from the old tree.

    Args:
        old (exp.Expression): node currently in the tree
        new (exp.Expression): node to put in its place
    """
    old.replace(new)
    # Cached columns/sources may describe the pre-replacement tree.
    self.clear_cache()
Replace old with new.
This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Tables collected from this scope's expression.

    Returns:
        list[exp.Table]: tables
    """
    # NOTE: `_ensure_collected` is expected to populate `_tables` on first use.
    self._ensure_collected()
    collected_tables = self._tables
    return collected_tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    Common table expressions collected from this scope.

    Returns:
        list[exp.CTE]: ctes
    """
    # NOTE: `_ensure_collected` is expected to populate `_ctes` on first use.
    self._ensure_collected()
    collected_ctes = self._ctes
    return collected_ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables collected from this scope.

    A derived table is a query used as a source, e.g. the inner query in
    `SELECT * FROM (SELECT ...)`.

    Returns:
        list[exp.Subquery]: derived tables
    """
    # NOTE: `_ensure_collected` is expected to populate `_derived_tables` on first use.
    self._ensure_collected()
    collected = self._derived_tables
    return collected
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    User Defined Tabular Functions collected from this scope.

    Returns:
        list[exp.UDTF]: UDTFs
    """
    # NOTE: `_ensure_collected` is expected to populate `_udtfs` on first use.
    self._ensure_collected()
    collected = self._udtfs
    return collected
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries collected from this scope.

    For example, the inner query in `SELECT * FROM x WHERE a IN (SELECT ...)`.

    Returns:
        list[exp.Select | exp.SetOperation]: subqueries
    """
    # NOTE: `_ensure_collected` is expected to populate `_subqueries` on first use.
    self._ensure_collected()
    collected = self._subqueries
    return collected
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.SetOperation]: subqueries
@property
def stars(self) -> t.List[exp.Column | exp.Dot]:
    """
    Star expressions (as columns or dots) collected from this scope.
    """
    # NOTE: `_ensure_collected` is expected to populate `_stars` on first use.
    self._ensure_collected()
    collected = self._stars
    return collected
List of star expressions (columns or dots) in this scope.
@property
def column_index(self) -> t.Set[int]:
    """
    IDs (`id(...)`) of the column objects that belong to this scope's expression.
    """
    # NOTE: `_ensure_collected` is expected to populate `_column_index` on first use.
    self._ensure_collected()
    collected = self._column_index
    return collected
Set of column object IDs that belong to this scope's expression.
@property
def columns(self):
    """
    Columns that belong to this scope.

    Starts from the columns collected directly from this scope's expression. It then
    adds the external columns of child scopes that may correlate with it: subqueries,
    UDTFs, and derived tables that can be correlated. Only the candidates accepted by
    the ancestor rules in `_keep` below are retained.

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is not None:
        return self._columns

    self._ensure_collected()
    collected = self._raw_columns

    correlated_children = itertools.chain(
        self.subquery_scopes,
        self.udtf_scopes,
        (child for child in self.derived_table_scopes if child.can_be_correlated),
    )
    from_children = [col for child in correlated_children for col in child.external_columns]

    projection_names = set(self.expression.named_selects)

    def _keep(col):
        # Accept a column if ANY of the following rules holds.
        ancestor = col.find_ancestor(
            exp.Select,
            exp.Qualify,
            exp.Order,
            exp.Having,
            exp.Hint,
            exp.Table,
            exp.Star,
            exp.Distinct,
        )
        # No relevant ancestor, an explicit table qualifier, or a plain SELECT context.
        if not ancestor or col.table or isinstance(ancestor, exp.Select):
            return True
        # Inside a table reference, unless that table wraps a function call.
        if isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func):
            return True
        # ORDER BY / DISTINCT: a bare name matching one of this scope's projections is
        # treated as a projection reference and dropped, except under a window /
        # WITHIN GROUP or when the clause does not belong to a SELECT.
        if isinstance(ancestor, (exp.Order, exp.Distinct)) and (
            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
            or not isinstance(ancestor.parent, exp.Select)
            or col.name not in projection_names
        ):
            return True
        # Under a star, keep everything except the star's EXCEPT list.
        return isinstance(ancestor, exp.Star) and col.arg_key != "except_"

    self._columns = []
    for col in collected + from_children:
        if _keep(col):
            self._columns.append(col)

    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Sources that this scope actually selects from, keyed by name.

    Any table in a schema could be referenced, but only those that appear in a
    FROM or JOIN clause (i.e. in `references`) and are registered in `sources`
    count as selected. The right-hand side of SEMI/ANTI joins is excluded.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

    Raises:
        OptimizeError: if two selected references share the same name.
    """
    if self._selected_sources is None:
        picked = {}

        for alias, node in self.references:
            # SEMI/ANTI join right-hand sides are not selected sources.
            if alias in self._semi_anti_join_tables:
                continue
            if alias in picked:
                raise OptimizeError(f"Alias already used: {alias}")
            if alias not in self.sources:
                continue
            picked[alias] = (node, self.sources[alias])

        self._selected_sources = picked

    return self._selected_sources
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    (name, node) pairs for every table, derived table and UDTF in this scope.

    Tables are listed first under their alias (or name). Derived tables and UDTFs
    follow, named via `_get_source_alias`. They are represented by the node itself
    when it carries pivots, otherwise by its unnested form.

    Returns:
        list[tuple[str, exp.Expression]]: reference names and nodes
    """
    if self._references is None:
        # The cache is published as an empty list before it is filled, as before.
        refs = self._references = []
        refs.extend((table.alias_or_name, table) for table in self.tables)

        for node in itertools.chain(self.derived_tables, self.udtfs):
            alias = _get_source_alias(node)
            refs.append((alias, node if node.args.get("pivots") else node.unnest()))

    return self._references
@property
def external_columns(self):
    """
    Columns that seem to point at sources living in an outer scope.

    For a set operation, this is the concatenation of the left and right union
    scopes' external columns. Otherwise, it is every column whose table is neither
    a source of this scope nor a SEMI/ANTI-joined table.

    Returns:
        list[exp.Column]: Column instances that don't reference sources in the current scope.
    """
    if self._external_columns is not None:
        return self._external_columns

    if isinstance(self.expression, exp.SetOperation):
        left_scope, right_scope = self.union_scopes
        self._external_columns = left_scope.external_columns + right_scope.external_columns
    else:
        self._external_columns = [
            col
            for col in self.columns
            if not (col.table in self.sources or col.table in self.semi_or_anti_join_tables)
        ]

    return self._external_columns
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def local_columns(self):
    """
    Columns of this scope that are not external.

    Returns:
        list[exp.Column]: Column instances that reference sources in the current scope.
    """
    if self._local_columns is None:
        outer = set(self.external_columns)
        self._local_columns = list(filter(lambda col: col not in outer, self.columns))

    return self._local_columns
Columns in this scope that are not external.
Returns:
list[exp.Column]: Column instances that reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns of this scope that carry no table qualifier.

    Returns:
        list[exp.Column]: Unqualified columns
    """
    bare = []
    for col in self.columns:
        if col.table:
            continue
        bare.append(col)
    return bare
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Join hints in this scope that reference tables.

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
            (an empty list when none have been recorded)
    """
    return [] if self._join_hints is None else self._join_hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Columns of this scope whose table qualifier equals `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances that reference `source_name`
    """
    return list(filter(lambda col: col.table == source_name, self.columns))
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self):
    """Whether this scope's type is `ScopeType.SUBQUERY`."""
    kind = self.scope_type
    return kind == ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """Whether this scope's type is `ScopeType.DERIVED_TABLE`."""
    kind = self.scope_type
    return kind == ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """Whether this scope's type is `ScopeType.UNION`."""
    kind = self.scope_type
    return kind == ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """Whether this scope's type is `ScopeType.CTE` (a common table expression)."""
    kind = self.scope_type
    return kind == ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """Whether this scope's type is `ScopeType.ROOT`."""
    kind = self.scope_type
    return kind == ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """Whether this scope's type is `ScopeType.UDTF` (User Defined Table Function)."""
    kind = self.scope_type
    return kind == ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    A falsy `old_name` (e.g. None) is looked up as the empty string. If no source is
    registered under that name, nothing happens.

    Like `add_source` and `remove_source`, a successful rename clears the scope's
    caches. Cached data such as `selected_sources` (keyed by source name) and
    `external_columns` (filtered against `self.sources`) would otherwise be stale.

    Args:
        old_name (str | None): current name of the source
        new_name (str): name to register the source under
    """
    old_name = old_name or ""
    if old_name not in self.sources:
        return

    self.sources[new_name] = self.sources.pop(old_name)
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name` (replacing any existing entry) and clear the caches."""
    registry = self.sources
    registry[name] = source
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Drop the source called `name` if it exists, then clear the caches either way."""
    if name in self.sources:
        del self.sources[name]
    self.clear_cache()
Remove a source from this scope
520 def traverse(self): 521 """ 522 Traverse the scope tree from this node. 523 524 Yields: 525 Scope: scope instances in depth-first-search post-order 526 """ 527 stack = [self] 528 result = [] 529 while stack: 530 scope = stack.pop() 531 result.append(scope) 532 stack.extend( 533 itertools.chain( 534 scope.cte_scopes, 535 scope.union_scopes, 536 scope.table_scopes, 537 scope.subquery_scopes, 538 ) 539 ) 540 541 yield from reversed(result)
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count how many times each source in this scope tree is referenced.

    Every selected source counts once per scope that selects it. SEMI/ANTI-joined
    sources are counted too, even though they are not selected.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(int)

    for scope in self.traverse():
        used = [source for _, source in scope.selected_sources.values()]
        # semi/anti join sources are not actually selected, but they still need a
        # reference so that they are not optimized away
        used.extend(
            scope.sources[name] for name in scope._semi_anti_join_tables if name in scope.sources
        )
        for source in used:
            counts[id(source)] += 1

    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances (empty if `expression` is not a
        query, DDL or DML node)
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root_scope = Scope(expression)
    return list(_traverse_scope(root_scope))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, or None if `expression` yields no scopes.
    """
    scopes = traverse_scope(expression)
    # traverse_scope lists inner scopes before outer ones, so the root is last.
    return scopes[-1] if scopes else None
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    A node that starts a child scope is itself yielded, but its subtree is not walked.
    Such nodes are: an `exp.CTE`; a node directly under FROM/JOIN for which
    `_is_derived_table` is true; an `exp.Query` whose parent is an `exp.UDTF`; or any
    instance of `exp.UNWRAPPED_QUERIES`. The starting `expression` itself is always
    walked into. For Subquery/UDTF scope boundaries, the `joins`, `laterals` and
    `pivots` args are still visited, because they belong to the current scope.

    Args:
        expression (exp.Expression): the node to start walking from.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that receives a node and returns True if
            the generator should stop traversing that node's subtree.
            NOTE(review): `prune` is not forwarded to the nested walks over
            `joins`, `laterals` and `pivots` described above.

    Yields:
        exp.Expression: the visited nodes
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    # NOTE: this relies on `walk` consulting `prune` for a node only after resuming
    # from yielding it, i.e. after this loop body has updated the flag.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: bool(crossed_scope_boundary or (prune and prune(n)))
    ):
        crossed_scope_boundary = False

        yield node

        # The root of the walk is never treated as a scope boundary.
        if node is expression:
            continue

        node_type = type(node)
        parent_type = type(node.parent)
        if (
            node_type is exp.CTE
            or (parent_type in (exp.From, exp.Join) and _is_derived_table(node))
            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if node_type is exp.Subquery or isinstance(node, exp.UDTF):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the node to start walking from.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node) -> bool): callable that receives a node and returns True if the generator should stop traversing that node's subtree.
Yields:
exp.Expression: the visited nodes
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the root node of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the requested types once, instead of once per visited node.
    wanted = tuple(ensure_collection(expression_types))
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the root node of the scope to search.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Return the first node of this scope that is an instance of one of the given types.

    Subscopes are NOT searched.

    Args:
        expression (exp.Expression): the root node of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None if nothing matches.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the root node of the scope to search.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.