sqlglot.optimizer.scope
1from __future__ import annotations 2 3import itertools 4import logging 5import typing as t 6from collections import defaultdict 7from enum import Enum, auto 8 9from sqlglot import exp 10from sqlglot.errors import OptimizeError 11from sqlglot.helper import ensure_collection, find_new_name, seq_get 12 13logger = logging.getLogger("sqlglot") 14 15TRAVERSABLES = (exp.Query, exp.DDL, exp.DML) 16 17 18class ScopeType(Enum): 19 ROOT = auto() 20 SUBQUERY = auto() 21 DERIVED_TABLE = auto() 22 CTE = auto() 23 UNION = auto() 24 UDTF = auto() 25 26 27class Scope: 28 """ 29 Selection scope. 30 31 Attributes: 32 expression (exp.Select|exp.SetOperation): Root expression of this scope 33 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 34 a Table expression or another Scope instance. For example: 35 SELECT * FROM x {"x": Table(this="x")} 36 SELECT * FROM x AS y {"y": Table(this="x")} 37 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 38 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 39 For example: 40 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 41 The LATERAL VIEW EXPLODE gets x as a source. 42 cte_sources (dict[str, Scope]): Sources from CTES 43 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 44 defines a column list for the alias of this scope, this is that list of columns. 45 For example: 46 SELECT * FROM (SELECT ...) 
AS y(col1, col2) 47 The inner query would have `["col1", "col2"]` for its `outer_columns` 48 parent (Scope): Parent scope 49 scope_type (ScopeType): Type of this scope, relative to it's parent 50 subquery_scopes (list[Scope]): List of all child scopes for subqueries 51 cte_scopes (list[Scope]): List of all child scopes for CTEs 52 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 53 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 54 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 55 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 56 a list of the left and right child scopes. 57 """ 58 59 def __init__( 60 self, 61 expression, 62 sources=None, 63 outer_columns=None, 64 parent=None, 65 scope_type=ScopeType.ROOT, 66 lateral_sources=None, 67 cte_sources=None, 68 can_be_correlated=None, 69 ): 70 self.expression = expression 71 self.sources = sources or {} 72 self.lateral_sources = lateral_sources or {} 73 self.cte_sources = cte_sources or {} 74 self.sources.update(self.lateral_sources) 75 self.sources.update(self.cte_sources) 76 self.outer_columns = outer_columns or [] 77 self.parent = parent 78 self.scope_type = scope_type 79 self.subquery_scopes = [] 80 self.derived_table_scopes = [] 81 self.table_scopes = [] 82 self.cte_scopes = [] 83 self.union_scopes = [] 84 self.udtf_scopes = [] 85 self.can_be_correlated = can_be_correlated 86 self.clear_cache() 87 88 def clear_cache(self): 89 self._collected = False 90 self._raw_columns = None 91 self._table_columns = None 92 self._stars = None 93 self._derived_tables = None 94 self._udtfs = None 95 self._tables = None 96 self._ctes = None 97 self._subqueries = None 98 self._selected_sources = None 99 self._columns = None 100 self._external_columns = None 101 self._join_hints = None 102 self._pivots = None 103 self._references = None 104 self._semi_anti_join_tables = None 
105 106 def branch( 107 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 108 ): 109 """Branch from the current scope to a new, inner scope""" 110 return Scope( 111 expression=expression.unnest(), 112 sources=sources.copy() if sources else None, 113 parent=self, 114 scope_type=scope_type, 115 cte_sources={**self.cte_sources, **(cte_sources or {})}, 116 lateral_sources=lateral_sources.copy() if lateral_sources else None, 117 can_be_correlated=self.can_be_correlated 118 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 119 **kwargs, 120 ) 121 122 def _collect(self): 123 self._tables = [] 124 self._ctes = [] 125 self._subqueries = [] 126 self._derived_tables = [] 127 self._udtfs = [] 128 self._raw_columns = [] 129 self._table_columns = [] 130 self._stars = [] 131 self._join_hints = [] 132 self._semi_anti_join_tables = set() 133 134 for node in self.walk(bfs=False): 135 if node is self.expression: 136 continue 137 138 if isinstance(node, exp.Dot) and node.is_star: 139 self._stars.append(node) 140 elif isinstance(node, exp.Column): 141 if isinstance(node.this, exp.Star): 142 self._stars.append(node) 143 else: 144 self._raw_columns.append(node) 145 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 146 parent = node.parent 147 if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join: 148 self._semi_anti_join_tables.add(node.alias_or_name) 149 150 self._tables.append(node) 151 elif isinstance(node, exp.JoinHint): 152 self._join_hints.append(node) 153 elif isinstance(node, exp.UDTF): 154 self._udtfs.append(node) 155 elif isinstance(node, exp.CTE): 156 self._ctes.append(node) 157 elif _is_derived_table(node) and _is_from_or_join(node): 158 self._derived_tables.append(node) 159 elif isinstance(node, exp.UNWRAPPED_QUERIES): 160 self._subqueries.append(node) 161 elif isinstance(node, exp.TableColumn): 162 self._table_columns.append(node) 163 164 self._collected = True 165 166 def 
_ensure_collected(self): 167 if not self._collected: 168 self._collect() 169 170 def walk(self, bfs=True, prune=None): 171 return walk_in_scope(self.expression, bfs=bfs, prune=None) 172 173 def find(self, *expression_types, bfs=True): 174 return find_in_scope(self.expression, expression_types, bfs=bfs) 175 176 def find_all(self, *expression_types, bfs=True): 177 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 178 179 def replace(self, old, new): 180 """ 181 Replace `old` with `new`. 182 183 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 184 185 Args: 186 old (exp.Expression): old node 187 new (exp.Expression): new node 188 """ 189 old.replace(new) 190 self.clear_cache() 191 192 @property 193 def tables(self): 194 """ 195 List of tables in this scope. 196 197 Returns: 198 list[exp.Table]: tables 199 """ 200 self._ensure_collected() 201 return self._tables 202 203 @property 204 def ctes(self): 205 """ 206 List of CTEs in this scope. 207 208 Returns: 209 list[exp.CTE]: ctes 210 """ 211 self._ensure_collected() 212 return self._ctes 213 214 @property 215 def derived_tables(self): 216 """ 217 List of derived tables in this scope. 218 219 For example: 220 SELECT * FROM (SELECT ...) <- that's a derived table 221 222 Returns: 223 list[exp.Subquery]: derived tables 224 """ 225 self._ensure_collected() 226 return self._derived_tables 227 228 @property 229 def udtfs(self): 230 """ 231 List of "User Defined Tabular Functions" in this scope. 232 233 Returns: 234 list[exp.UDTF]: UDTFs 235 """ 236 self._ensure_collected() 237 return self._udtfs 238 239 @property 240 def subqueries(self): 241 """ 242 List of subqueries in this scope. 243 244 For example: 245 SELECT * FROM x WHERE a IN (SELECT ...) 
<- that's a subquery 246 247 Returns: 248 list[exp.Select | exp.SetOperation]: subqueries 249 """ 250 self._ensure_collected() 251 return self._subqueries 252 253 @property 254 def stars(self) -> t.List[exp.Column | exp.Dot]: 255 """ 256 List of star expressions (columns or dots) in this scope. 257 """ 258 self._ensure_collected() 259 return self._stars 260 261 @property 262 def columns(self): 263 """ 264 List of columns in this scope. 265 266 Returns: 267 list[exp.Column]: Column instances in this scope, plus any 268 Columns that reference this scope from correlated subqueries. 269 """ 270 if self._columns is None: 271 self._ensure_collected() 272 columns = self._raw_columns 273 274 external_columns = [ 275 column 276 for scope in itertools.chain( 277 self.subquery_scopes, 278 self.udtf_scopes, 279 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 280 ) 281 for column in scope.external_columns 282 ] 283 284 named_selects = set(self.expression.named_selects) 285 286 self._columns = [] 287 for column in columns + external_columns: 288 ancestor = column.find_ancestor( 289 exp.Select, 290 exp.Qualify, 291 exp.Order, 292 exp.Having, 293 exp.Hint, 294 exp.Table, 295 exp.Star, 296 exp.Distinct, 297 ) 298 if ( 299 not ancestor 300 or column.table 301 or isinstance(ancestor, exp.Select) 302 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 303 or ( 304 isinstance(ancestor, (exp.Order, exp.Distinct)) 305 and ( 306 isinstance(ancestor.parent, (exp.Window, exp.WithinGroup)) 307 or not isinstance(ancestor.parent, exp.Select) 308 or column.name not in named_selects 309 ) 310 ) 311 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") 312 ): 313 self._columns.append(column) 314 315 return self._columns 316 317 @property 318 def table_columns(self): 319 if self._table_columns is None: 320 self._ensure_collected() 321 322 return self._table_columns 323 324 @property 325 def selected_sources(self): 326 """ 327 
Mapping of nodes and sources that are actually selected from in this scope. 328 329 That is, all tables in a schema are selectable at any point. But a 330 table only becomes a selected source if it's included in a FROM or JOIN clause. 331 332 Returns: 333 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 334 """ 335 if self._selected_sources is None: 336 result = {} 337 338 for name, node in self.references: 339 if name in self._semi_anti_join_tables: 340 # The RHS table of SEMI/ANTI joins shouldn't be collected as a 341 # selected source 342 continue 343 344 if name in result: 345 raise OptimizeError(f"Alias already used: {name}") 346 if name in self.sources: 347 result[name] = (node, self.sources[name]) 348 349 self._selected_sources = result 350 return self._selected_sources 351 352 @property 353 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 354 if self._references is None: 355 self._references = [] 356 357 for table in self.tables: 358 self._references.append((table.alias_or_name, table)) 359 for expression in itertools.chain(self.derived_tables, self.udtfs): 360 self._references.append( 361 ( 362 _get_source_alias(expression), 363 expression if expression.args.get("pivots") else expression.unnest(), 364 ) 365 ) 366 367 return self._references 368 369 @property 370 def external_columns(self): 371 """ 372 Columns that appear to reference sources in outer scopes. 373 374 Returns: 375 list[exp.Column]: Column instances that don't reference 376 sources in the current scope. 
377 """ 378 if self._external_columns is None: 379 if isinstance(self.expression, exp.SetOperation): 380 left, right = self.union_scopes 381 self._external_columns = left.external_columns + right.external_columns 382 else: 383 self._external_columns = [ 384 c 385 for c in self.columns 386 if c.table not in self.selected_sources 387 and c.table not in self.semi_or_anti_join_tables 388 ] 389 390 return self._external_columns 391 392 @property 393 def unqualified_columns(self): 394 """ 395 Unqualified columns in the current scope. 396 397 Returns: 398 list[exp.Column]: Unqualified columns 399 """ 400 return [c for c in self.columns if not c.table] 401 402 @property 403 def join_hints(self): 404 """ 405 Hints that exist in the scope that reference tables 406 407 Returns: 408 list[exp.JoinHint]: Join hints that are referenced within the scope 409 """ 410 if self._join_hints is None: 411 return [] 412 return self._join_hints 413 414 @property 415 def pivots(self): 416 if not self._pivots: 417 self._pivots = [ 418 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 419 ] 420 421 return self._pivots 422 423 @property 424 def semi_or_anti_join_tables(self): 425 return self._semi_anti_join_tables or set() 426 427 def source_columns(self, source_name): 428 """ 429 Get all columns in the current scope for a particular source. 
430 431 Args: 432 source_name (str): Name of the source 433 Returns: 434 list[exp.Column]: Column instances that reference `source_name` 435 """ 436 return [column for column in self.columns if column.table == source_name] 437 438 @property 439 def is_subquery(self): 440 """Determine if this scope is a subquery""" 441 return self.scope_type == ScopeType.SUBQUERY 442 443 @property 444 def is_derived_table(self): 445 """Determine if this scope is a derived table""" 446 return self.scope_type == ScopeType.DERIVED_TABLE 447 448 @property 449 def is_union(self): 450 """Determine if this scope is a union""" 451 return self.scope_type == ScopeType.UNION 452 453 @property 454 def is_cte(self): 455 """Determine if this scope is a common table expression""" 456 return self.scope_type == ScopeType.CTE 457 458 @property 459 def is_root(self): 460 """Determine if this is the root scope""" 461 return self.scope_type == ScopeType.ROOT 462 463 @property 464 def is_udtf(self): 465 """Determine if this scope is a UDTF (User Defined Table Function)""" 466 return self.scope_type == ScopeType.UDTF 467 468 @property 469 def is_correlated_subquery(self): 470 """Determine if this scope is a correlated subquery""" 471 return bool(self.can_be_correlated and self.external_columns) 472 473 def rename_source(self, old_name, new_name): 474 """Rename a source in this scope""" 475 columns = self.sources.pop(old_name or "", []) 476 self.sources[new_name] = columns 477 478 def add_source(self, name, source): 479 """Add a source to this scope""" 480 self.sources[name] = source 481 self.clear_cache() 482 483 def remove_source(self, name): 484 """Remove a source from this scope""" 485 self.sources.pop(name, None) 486 self.clear_cache() 487 488 def __repr__(self): 489 return f"Scope<{self.expression.sql()}>" 490 491 def traverse(self): 492 """ 493 Traverse the scope tree from this node. 
494 495 Yields: 496 Scope: scope instances in depth-first-search post-order 497 """ 498 stack = [self] 499 result = [] 500 while stack: 501 scope = stack.pop() 502 result.append(scope) 503 stack.extend( 504 itertools.chain( 505 scope.cte_scopes, 506 scope.union_scopes, 507 scope.table_scopes, 508 scope.subquery_scopes, 509 ) 510 ) 511 512 yield from reversed(result) 513 514 def ref_count(self): 515 """ 516 Count the number of times each scope in this tree is referenced. 517 518 Returns: 519 dict[int, int]: Mapping of Scope instance ID to reference count 520 """ 521 scope_ref_count = defaultdict(lambda: 0) 522 523 for scope in self.traverse(): 524 for _, source in scope.selected_sources.values(): 525 scope_ref_count[id(source)] += 1 526 527 for name in scope._semi_anti_join_tables: 528 # semi/anti join sources are not actually selected but we still need to 529 # increment their ref count to avoid them being optimized away 530 if name in scope.sources: 531 scope_ref_count[id(scope.sources[name])] += 1 532 533 return scope_ref_count 534 535 536def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 537 """ 538 Traverse an expression by its "scopes". 539 540 "Scope" represents the current context of a Select statement. 541 542 This is helpful for optimizing queries, where we need more information than 543 the expression tree itself. For example, we might care about the source 544 names within a subquery. Returns a list because a generator could result in 545 incomplete properties which is confusing. 
546 547 Examples: 548 >>> import sqlglot 549 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 550 >>> scopes = traverse_scope(expression) 551 >>> scopes[0].expression.sql(), list(scopes[0].sources) 552 ('SELECT a FROM x', ['x']) 553 >>> scopes[1].expression.sql(), list(scopes[1].sources) 554 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 555 556 Args: 557 expression: Expression to traverse 558 559 Returns: 560 A list of the created scope instances 561 """ 562 if isinstance(expression, TRAVERSABLES): 563 return list(_traverse_scope(Scope(expression))) 564 return [] 565 566 567def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 568 """ 569 Build a scope tree. 570 571 Args: 572 expression: Expression to build the scope tree for. 573 574 Returns: 575 The root scope 576 """ 577 return seq_get(traverse_scope(expression), -1) 578 579 580def _traverse_scope(scope): 581 expression = scope.expression 582 583 if isinstance(expression, exp.Select): 584 yield from _traverse_select(scope) 585 elif isinstance(expression, exp.SetOperation): 586 yield from _traverse_ctes(scope) 587 yield from _traverse_union(scope) 588 return 589 elif isinstance(expression, exp.Subquery): 590 if scope.is_root: 591 yield from _traverse_select(scope) 592 else: 593 yield from _traverse_subqueries(scope) 594 elif isinstance(expression, exp.Table): 595 yield from _traverse_tables(scope) 596 elif isinstance(expression, exp.UDTF): 597 yield from _traverse_udtfs(scope) 598 elif isinstance(expression, exp.DDL): 599 if isinstance(expression.expression, exp.Query): 600 yield from _traverse_ctes(scope) 601 yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources)) 602 return 603 elif isinstance(expression, exp.DML): 604 yield from _traverse_ctes(scope) 605 for query in find_all_in_scope(expression, exp.Query): 606 # This check ensures we don't yield the CTE/nested queries twice 607 if not isinstance(query.parent, (exp.CTE, exp.Subquery)): 
608 yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources)) 609 return 610 else: 611 logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression)) 612 return 613 614 yield scope 615 616 617def _traverse_select(scope): 618 yield from _traverse_ctes(scope) 619 yield from _traverse_tables(scope) 620 yield from _traverse_subqueries(scope) 621 622 623def _traverse_union(scope): 624 prev_scope = None 625 union_scope_stack = [scope] 626 expression_stack = [scope.expression.right, scope.expression.left] 627 628 while expression_stack: 629 expression = expression_stack.pop() 630 union_scope = union_scope_stack[-1] 631 632 new_scope = union_scope.branch( 633 expression, 634 outer_columns=union_scope.outer_columns, 635 scope_type=ScopeType.UNION, 636 ) 637 638 if isinstance(expression, exp.SetOperation): 639 yield from _traverse_ctes(new_scope) 640 641 union_scope_stack.append(new_scope) 642 expression_stack.extend([expression.right, expression.left]) 643 continue 644 645 for scope in _traverse_scope(new_scope): 646 yield scope 647 648 if prev_scope: 649 union_scope_stack.pop() 650 union_scope.union_scopes = [prev_scope, scope] 651 prev_scope = union_scope 652 653 yield union_scope 654 else: 655 prev_scope = scope 656 657 658def _traverse_ctes(scope): 659 sources = {} 660 661 for cte in scope.ctes: 662 cte_name = cte.alias 663 664 # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. 665 # thus the recursive scope is the first section of the union. 
666 with_ = scope.expression.args.get("with") 667 if with_ and with_.recursive: 668 union = cte.this 669 670 if isinstance(union, exp.SetOperation): 671 sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE) 672 673 child_scope = None 674 675 for child_scope in _traverse_scope( 676 scope.branch( 677 cte.this, 678 cte_sources=sources, 679 outer_columns=cte.alias_column_names, 680 scope_type=ScopeType.CTE, 681 ) 682 ): 683 yield child_scope 684 685 # append the final child_scope yielded 686 if child_scope: 687 sources[cte_name] = child_scope 688 scope.cte_scopes.append(child_scope) 689 690 scope.sources.update(sources) 691 scope.cte_sources.update(sources) 692 693 694def _is_derived_table(expression: exp.Subquery) -> bool: 695 """ 696 We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", 697 as it doesn't introduce a new scope. If an alias is present, it shadows all names 698 under the Subquery, so that's one exception to this rule. 699 """ 700 return isinstance(expression, exp.Subquery) and bool( 701 expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES) 702 ) 703 704 705def _is_from_or_join(expression: exp.Expression) -> bool: 706 """ 707 Determine if `expression` is the FROM or JOIN clause of a SELECT statement. 
708 """ 709 parent = expression.parent 710 711 # Subqueries can be arbitrarily nested 712 while isinstance(parent, exp.Subquery): 713 parent = parent.parent 714 715 return isinstance(parent, (exp.From, exp.Join)) 716 717 718def _traverse_tables(scope): 719 sources = {} 720 721 # Traverse FROMs, JOINs, and LATERALs in the order they are defined 722 expressions = [] 723 from_ = scope.expression.args.get("from") 724 if from_: 725 expressions.append(from_.this) 726 727 for join in scope.expression.args.get("joins") or []: 728 expressions.append(join.this) 729 730 if isinstance(scope.expression, exp.Table): 731 expressions.append(scope.expression) 732 733 expressions.extend(scope.expression.args.get("laterals") or []) 734 735 for expression in expressions: 736 if isinstance(expression, exp.Final): 737 expression = expression.this 738 if isinstance(expression, exp.Table): 739 table_name = expression.name 740 source_name = expression.alias_or_name 741 742 if table_name in scope.sources and not expression.db: 743 # This is a reference to a parent source (e.g. a CTE), not an actual table, unless 744 # it is pivoted, because then we get back a new table and hence a new source. 
745 pivots = expression.args.get("pivots") 746 if pivots: 747 sources[pivots[0].alias] = expression 748 else: 749 sources[source_name] = scope.sources[table_name] 750 elif source_name in sources: 751 sources[find_new_name(sources, table_name)] = expression 752 else: 753 sources[source_name] = expression 754 755 # Make sure to not include the joins twice 756 if expression is not scope.expression: 757 expressions.extend(join.this for join in expression.args.get("joins") or []) 758 759 continue 760 761 if not isinstance(expression, exp.DerivedTable): 762 continue 763 764 if isinstance(expression, exp.UDTF): 765 lateral_sources = sources 766 scope_type = ScopeType.UDTF 767 scopes = scope.udtf_scopes 768 elif _is_derived_table(expression): 769 lateral_sources = None 770 scope_type = ScopeType.DERIVED_TABLE 771 scopes = scope.derived_table_scopes 772 expressions.extend(join.this for join in expression.args.get("joins") or []) 773 else: 774 # Makes sure we check for possible sources in nested table constructs 775 expressions.append(expression.this) 776 expressions.extend(join.this for join in expression.args.get("joins") or []) 777 continue 778 779 child_scope = None 780 781 for child_scope in _traverse_scope( 782 scope.branch( 783 expression, 784 lateral_sources=lateral_sources, 785 outer_columns=expression.alias_column_names, 786 scope_type=scope_type, 787 ) 788 ): 789 yield child_scope 790 791 # Tables without aliases will be set as "" 792 # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. 793 # Until then, this means that only a single, unaliased derived table is allowed (rather, 794 # the latest one wins. 
795 sources[_get_source_alias(expression)] = child_scope 796 797 # append the final child_scope yielded 798 if child_scope: 799 scopes.append(child_scope) 800 scope.table_scopes.append(child_scope) 801 802 scope.sources.update(sources) 803 804 805def _traverse_subqueries(scope): 806 for subquery in scope.subqueries: 807 top = None 808 for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): 809 yield child_scope 810 top = child_scope 811 scope.subquery_scopes.append(top) 812 813 814def _traverse_udtfs(scope): 815 if isinstance(scope.expression, exp.Unnest): 816 expressions = scope.expression.expressions 817 elif isinstance(scope.expression, exp.Lateral): 818 expressions = [scope.expression.this] 819 else: 820 expressions = [] 821 822 sources = {} 823 for expression in expressions: 824 if _is_derived_table(expression): 825 top = None 826 for child_scope in _traverse_scope( 827 scope.branch( 828 expression, 829 scope_type=ScopeType.SUBQUERY, 830 outer_columns=expression.alias_column_names, 831 ) 832 ): 833 yield child_scope 834 top = child_scope 835 sources[_get_source_alias(expression)] = child_scope 836 837 scope.subquery_scopes.append(top) 838 839 scope.sources.update(sources) 840 841 842def walk_in_scope(expression, bfs=True, prune=None): 843 """ 844 Returns a generator object which visits all nodes in the syntrax tree, stopping at 845 nodes that start child scopes. 846 847 Args: 848 expression (exp.Expression): 849 bfs (bool): if set to True the BFS traversal order will be applied, 850 otherwise the DFS traversal will be used instead. 851 prune ((node, parent, arg_key) -> bool): callable that returns True if 852 the generator should stop traversing this branch of the tree. 853 854 Yields: 855 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key 856 """ 857 # We'll use this variable to pass state into the dfs generator. 858 # Whenever we set it to True, we exclude a subtree from traversal. 
859 crossed_scope_boundary = False 860 861 for node in expression.walk( 862 bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n)) 863 ): 864 crossed_scope_boundary = False 865 866 yield node 867 868 if node is expression: 869 continue 870 871 if ( 872 isinstance(node, exp.CTE) 873 or ( 874 isinstance(node.parent, (exp.From, exp.Join, exp.Subquery)) 875 and _is_derived_table(node) 876 ) 877 or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query)) 878 or isinstance(node, exp.UNWRAPPED_QUERIES) 879 ): 880 crossed_scope_boundary = True 881 882 if isinstance(node, (exp.Subquery, exp.UDTF)): 883 # The following args are not actually in the inner scope, so we should visit them 884 for key in ("joins", "laterals", "pivots"): 885 for arg in node.args.get(key) or []: 886 yield from walk_in_scope(arg, bfs=bfs) 887 888 889def find_all_in_scope(expression, expression_types, bfs=True): 890 """ 891 Returns a generator object which visits all nodes in this scope and only yields those that 892 match at least one of the specified expression types. 893 894 This does NOT traverse into subscopes. 895 896 Args: 897 expression (exp.Expression): 898 expression_types (tuple[type]|type): the expression type(s) to match. 899 bfs (bool): True to use breadth-first search, False to use depth-first. 900 901 Yields: 902 exp.Expression: nodes 903 """ 904 for expression in walk_in_scope(expression, bfs=bfs): 905 if isinstance(expression, tuple(ensure_collection(expression_types))): 906 yield expression 907 908 909def find_in_scope(expression, expression_types, bfs=True): 910 """ 911 Returns the first node in this scope which matches at least one of the specified types. 912 913 This does NOT traverse into subscopes. 914 915 Args: 916 expression (exp.Expression): 917 expression_types (tuple[type]|type): the expression type(s) to match. 918 bfs (bool): True to use breadth-first search, False to use depth-first. 
919 920 Returns: 921 exp.Expression: the node which matches the criteria or None if no node matching 922 the criteria was found. 923 """ 924 return next(find_all_in_scope(expression, expression_types, bfs=bfs), None) 925 926 927def _get_source_alias(expression): 928 alias_arg = expression.args.get("alias") 929 alias_name = expression.alias 930 931 if not alias_name and isinstance(alias_arg, exp.TableAlias) and len(alias_arg.columns) == 1: 932 alias_name = alias_arg.columns[0].name 933 934 return alias_name
class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()
An enumeration.
28class Scope: 29 """ 30 Selection scope. 31 32 Attributes: 33 expression (exp.Select|exp.SetOperation): Root expression of this scope 34 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 35 a Table expression or another Scope instance. For example: 36 SELECT * FROM x {"x": Table(this="x")} 37 SELECT * FROM x AS y {"y": Table(this="x")} 38 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 39 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 40 For example: 41 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 42 The LATERAL VIEW EXPLODE gets x as a source. 43 cte_sources (dict[str, Scope]): Sources from CTES 44 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 45 defines a column list for the alias of this scope, this is that list of columns. 46 For example: 47 SELECT * FROM (SELECT ...) AS y(col1, col2) 48 The inner query would have `["col1", "col2"]` for its `outer_columns` 49 parent (Scope): Parent scope 50 scope_type (ScopeType): Type of this scope, relative to it's parent 51 subquery_scopes (list[Scope]): List of all child scopes for subqueries 52 cte_scopes (list[Scope]): List of all child scopes for CTEs 53 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 54 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 55 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 56 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 57 a list of the left and right child scopes. 
58 """ 59 60 def __init__( 61 self, 62 expression, 63 sources=None, 64 outer_columns=None, 65 parent=None, 66 scope_type=ScopeType.ROOT, 67 lateral_sources=None, 68 cte_sources=None, 69 can_be_correlated=None, 70 ): 71 self.expression = expression 72 self.sources = sources or {} 73 self.lateral_sources = lateral_sources or {} 74 self.cte_sources = cte_sources or {} 75 self.sources.update(self.lateral_sources) 76 self.sources.update(self.cte_sources) 77 self.outer_columns = outer_columns or [] 78 self.parent = parent 79 self.scope_type = scope_type 80 self.subquery_scopes = [] 81 self.derived_table_scopes = [] 82 self.table_scopes = [] 83 self.cte_scopes = [] 84 self.union_scopes = [] 85 self.udtf_scopes = [] 86 self.can_be_correlated = can_be_correlated 87 self.clear_cache() 88 89 def clear_cache(self): 90 self._collected = False 91 self._raw_columns = None 92 self._table_columns = None 93 self._stars = None 94 self._derived_tables = None 95 self._udtfs = None 96 self._tables = None 97 self._ctes = None 98 self._subqueries = None 99 self._selected_sources = None 100 self._columns = None 101 self._external_columns = None 102 self._join_hints = None 103 self._pivots = None 104 self._references = None 105 self._semi_anti_join_tables = None 106 107 def branch( 108 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 109 ): 110 """Branch from the current scope to a new, inner scope""" 111 return Scope( 112 expression=expression.unnest(), 113 sources=sources.copy() if sources else None, 114 parent=self, 115 scope_type=scope_type, 116 cte_sources={**self.cte_sources, **(cte_sources or {})}, 117 lateral_sources=lateral_sources.copy() if lateral_sources else None, 118 can_be_correlated=self.can_be_correlated 119 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 120 **kwargs, 121 ) 122 123 def _collect(self): 124 self._tables = [] 125 self._ctes = [] 126 self._subqueries = [] 127 self._derived_tables = [] 128 self._udtfs = [] 129 
self._raw_columns = [] 130 self._table_columns = [] 131 self._stars = [] 132 self._join_hints = [] 133 self._semi_anti_join_tables = set() 134 135 for node in self.walk(bfs=False): 136 if node is self.expression: 137 continue 138 139 if isinstance(node, exp.Dot) and node.is_star: 140 self._stars.append(node) 141 elif isinstance(node, exp.Column): 142 if isinstance(node.this, exp.Star): 143 self._stars.append(node) 144 else: 145 self._raw_columns.append(node) 146 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 147 parent = node.parent 148 if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join: 149 self._semi_anti_join_tables.add(node.alias_or_name) 150 151 self._tables.append(node) 152 elif isinstance(node, exp.JoinHint): 153 self._join_hints.append(node) 154 elif isinstance(node, exp.UDTF): 155 self._udtfs.append(node) 156 elif isinstance(node, exp.CTE): 157 self._ctes.append(node) 158 elif _is_derived_table(node) and _is_from_or_join(node): 159 self._derived_tables.append(node) 160 elif isinstance(node, exp.UNWRAPPED_QUERIES): 161 self._subqueries.append(node) 162 elif isinstance(node, exp.TableColumn): 163 self._table_columns.append(node) 164 165 self._collected = True 166 167 def _ensure_collected(self): 168 if not self._collected: 169 self._collect() 170 171 def walk(self, bfs=True, prune=None): 172 return walk_in_scope(self.expression, bfs=bfs, prune=None) 173 174 def find(self, *expression_types, bfs=True): 175 return find_in_scope(self.expression, expression_types, bfs=bfs) 176 177 def find_all(self, *expression_types, bfs=True): 178 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 179 180 def replace(self, old, new): 181 """ 182 Replace `old` with `new`. 183 184 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 
185 186 Args: 187 old (exp.Expression): old node 188 new (exp.Expression): new node 189 """ 190 old.replace(new) 191 self.clear_cache() 192 193 @property 194 def tables(self): 195 """ 196 List of tables in this scope. 197 198 Returns: 199 list[exp.Table]: tables 200 """ 201 self._ensure_collected() 202 return self._tables 203 204 @property 205 def ctes(self): 206 """ 207 List of CTEs in this scope. 208 209 Returns: 210 list[exp.CTE]: ctes 211 """ 212 self._ensure_collected() 213 return self._ctes 214 215 @property 216 def derived_tables(self): 217 """ 218 List of derived tables in this scope. 219 220 For example: 221 SELECT * FROM (SELECT ...) <- that's a derived table 222 223 Returns: 224 list[exp.Subquery]: derived tables 225 """ 226 self._ensure_collected() 227 return self._derived_tables 228 229 @property 230 def udtfs(self): 231 """ 232 List of "User Defined Tabular Functions" in this scope. 233 234 Returns: 235 list[exp.UDTF]: UDTFs 236 """ 237 self._ensure_collected() 238 return self._udtfs 239 240 @property 241 def subqueries(self): 242 """ 243 List of subqueries in this scope. 244 245 For example: 246 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 247 248 Returns: 249 list[exp.Select | exp.SetOperation]: subqueries 250 """ 251 self._ensure_collected() 252 return self._subqueries 253 254 @property 255 def stars(self) -> t.List[exp.Column | exp.Dot]: 256 """ 257 List of star expressions (columns or dots) in this scope. 258 """ 259 self._ensure_collected() 260 return self._stars 261 262 @property 263 def columns(self): 264 """ 265 List of columns in this scope. 266 267 Returns: 268 list[exp.Column]: Column instances in this scope, plus any 269 Columns that reference this scope from correlated subqueries. 
270 """ 271 if self._columns is None: 272 self._ensure_collected() 273 columns = self._raw_columns 274 275 external_columns = [ 276 column 277 for scope in itertools.chain( 278 self.subquery_scopes, 279 self.udtf_scopes, 280 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 281 ) 282 for column in scope.external_columns 283 ] 284 285 named_selects = set(self.expression.named_selects) 286 287 self._columns = [] 288 for column in columns + external_columns: 289 ancestor = column.find_ancestor( 290 exp.Select, 291 exp.Qualify, 292 exp.Order, 293 exp.Having, 294 exp.Hint, 295 exp.Table, 296 exp.Star, 297 exp.Distinct, 298 ) 299 if ( 300 not ancestor 301 or column.table 302 or isinstance(ancestor, exp.Select) 303 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 304 or ( 305 isinstance(ancestor, (exp.Order, exp.Distinct)) 306 and ( 307 isinstance(ancestor.parent, (exp.Window, exp.WithinGroup)) 308 or not isinstance(ancestor.parent, exp.Select) 309 or column.name not in named_selects 310 ) 311 ) 312 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") 313 ): 314 self._columns.append(column) 315 316 return self._columns 317 318 @property 319 def table_columns(self): 320 if self._table_columns is None: 321 self._ensure_collected() 322 323 return self._table_columns 324 325 @property 326 def selected_sources(self): 327 """ 328 Mapping of nodes and sources that are actually selected from in this scope. 329 330 That is, all tables in a schema are selectable at any point. But a 331 table only becomes a selected source if it's included in a FROM or JOIN clause. 
332 333 Returns: 334 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 335 """ 336 if self._selected_sources is None: 337 result = {} 338 339 for name, node in self.references: 340 if name in self._semi_anti_join_tables: 341 # The RHS table of SEMI/ANTI joins shouldn't be collected as a 342 # selected source 343 continue 344 345 if name in result: 346 raise OptimizeError(f"Alias already used: {name}") 347 if name in self.sources: 348 result[name] = (node, self.sources[name]) 349 350 self._selected_sources = result 351 return self._selected_sources 352 353 @property 354 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 355 if self._references is None: 356 self._references = [] 357 358 for table in self.tables: 359 self._references.append((table.alias_or_name, table)) 360 for expression in itertools.chain(self.derived_tables, self.udtfs): 361 self._references.append( 362 ( 363 _get_source_alias(expression), 364 expression if expression.args.get("pivots") else expression.unnest(), 365 ) 366 ) 367 368 return self._references 369 370 @property 371 def external_columns(self): 372 """ 373 Columns that appear to reference sources in outer scopes. 374 375 Returns: 376 list[exp.Column]: Column instances that don't reference 377 sources in the current scope. 378 """ 379 if self._external_columns is None: 380 if isinstance(self.expression, exp.SetOperation): 381 left, right = self.union_scopes 382 self._external_columns = left.external_columns + right.external_columns 383 else: 384 self._external_columns = [ 385 c 386 for c in self.columns 387 if c.table not in self.selected_sources 388 and c.table not in self.semi_or_anti_join_tables 389 ] 390 391 return self._external_columns 392 393 @property 394 def unqualified_columns(self): 395 """ 396 Unqualified columns in the current scope. 
397 398 Returns: 399 list[exp.Column]: Unqualified columns 400 """ 401 return [c for c in self.columns if not c.table] 402 403 @property 404 def join_hints(self): 405 """ 406 Hints that exist in the scope that reference tables 407 408 Returns: 409 list[exp.JoinHint]: Join hints that are referenced within the scope 410 """ 411 if self._join_hints is None: 412 return [] 413 return self._join_hints 414 415 @property 416 def pivots(self): 417 if not self._pivots: 418 self._pivots = [ 419 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 420 ] 421 422 return self._pivots 423 424 @property 425 def semi_or_anti_join_tables(self): 426 return self._semi_anti_join_tables or set() 427 428 def source_columns(self, source_name): 429 """ 430 Get all columns in the current scope for a particular source. 431 432 Args: 433 source_name (str): Name of the source 434 Returns: 435 list[exp.Column]: Column instances that reference `source_name` 436 """ 437 return [column for column in self.columns if column.table == source_name] 438 439 @property 440 def is_subquery(self): 441 """Determine if this scope is a subquery""" 442 return self.scope_type == ScopeType.SUBQUERY 443 444 @property 445 def is_derived_table(self): 446 """Determine if this scope is a derived table""" 447 return self.scope_type == ScopeType.DERIVED_TABLE 448 449 @property 450 def is_union(self): 451 """Determine if this scope is a union""" 452 return self.scope_type == ScopeType.UNION 453 454 @property 455 def is_cte(self): 456 """Determine if this scope is a common table expression""" 457 return self.scope_type == ScopeType.CTE 458 459 @property 460 def is_root(self): 461 """Determine if this is the root scope""" 462 return self.scope_type == ScopeType.ROOT 463 464 @property 465 def is_udtf(self): 466 """Determine if this scope is a UDTF (User Defined Table Function)""" 467 return self.scope_type == ScopeType.UDTF 468 469 @property 470 def is_correlated_subquery(self): 471 """Determine if 
this scope is a correlated subquery""" 472 return bool(self.can_be_correlated and self.external_columns) 473 474 def rename_source(self, old_name, new_name): 475 """Rename a source in this scope""" 476 columns = self.sources.pop(old_name or "", []) 477 self.sources[new_name] = columns 478 479 def add_source(self, name, source): 480 """Add a source to this scope""" 481 self.sources[name] = source 482 self.clear_cache() 483 484 def remove_source(self, name): 485 """Remove a source from this scope""" 486 self.sources.pop(name, None) 487 self.clear_cache() 488 489 def __repr__(self): 490 return f"Scope<{self.expression.sql()}>" 491 492 def traverse(self): 493 """ 494 Traverse the scope tree from this node. 495 496 Yields: 497 Scope: scope instances in depth-first-search post-order 498 """ 499 stack = [self] 500 result = [] 501 while stack: 502 scope = stack.pop() 503 result.append(scope) 504 stack.extend( 505 itertools.chain( 506 scope.cte_scopes, 507 scope.union_scopes, 508 scope.table_scopes, 509 scope.subquery_scopes, 510 ) 511 ) 512 513 yield from reversed(result) 514 515 def ref_count(self): 516 """ 517 Count the number of times each scope in this tree is referenced. 518 519 Returns: 520 dict[int, int]: Mapping of Scope instance ID to reference count 521 """ 522 scope_ref_count = defaultdict(lambda: 0) 523 524 for scope in self.traverse(): 525 for _, source in scope.selected_sources.values(): 526 scope_ref_count[id(source)] += 1 527 528 for name in scope._semi_anti_join_tables: 529 # semi/anti join sources are not actually selected but we still need to 530 # increment their ref count to avoid them being optimized away 531 if name in scope.sources: 532 scope_ref_count[id(scope.sources[name])] += 1 533 534 return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.SetOperation): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have ["col1", "col2"] for its outer_columns
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_columns=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
    can_be_correlated=None,
):
    """
    Set up a scope rooted at `expression`.

    Args:
        expression: Root expression of this scope.
        sources: Mapping of source name to source. A non-empty mapping is used
            as-is (not copied) and is updated in place with the lateral and
            CTE sources; a falsy value is replaced by a new dict.
        outer_columns: Column list defined by the outer query for this scope's alias.
        parent: Parent scope, if any.
        scope_type: Type of this scope relative to its parent.
        lateral_sources: Sources coming from laterals.
        cte_sources: Sources coming from CTEs.
        can_be_correlated: Stored unchanged on the instance.
    """
    self.expression = expression

    # Falsy arguments are swapped for fresh containers.
    self.sources = sources if sources else {}
    self.lateral_sources = lateral_sources if lateral_sources else {}
    self.cte_sources = cte_sources if cte_sources else {}

    # Lateral sources are merged first, so CTE sources win on a name clash.
    for extra_sources in (self.lateral_sources, self.cte_sources):
        self.sources.update(extra_sources)

    self.outer_columns = outer_columns if outer_columns else []
    self.parent = parent
    self.scope_type = scope_type

    # Child scope lists start out empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    self.can_be_correlated = can_be_correlated
    self.clear_cache()
def clear_cache(self):
    """Mark the scope as not collected and reset every cached attribute to None."""
    self._collected = False

    for cached_attr in (
        "_raw_columns",
        "_table_columns",
        "_stars",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
        "_semi_anti_join_tables",
    ):
        setattr(self, cached_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Create a child scope of this scope.

    Args:
        expression: Expression of the new scope; `unnest()` is applied to it.
        scope_type: Type of the new scope.
        sources: Optional sources for the child; copied when non-empty.
        cte_sources: Optional CTE sources, layered over this scope's CTE sources.
        lateral_sources: Optional lateral sources; copied when non-empty.
        **kwargs: Forwarded to the `Scope` constructor.

    Returns:
        Scope: the new scope, with this scope as its parent.
    """
    inner_expression = expression.unnest()
    child_sources = sources.copy() if sources else None

    # The child sees the parent's CTEs; its own entries override same-named ones.
    visible_ctes = {**self.cte_sources, **(cte_sources or {})}
    child_laterals = lateral_sources.copy() if lateral_sources else None

    # Subquery and UDTF scopes can be correlated; otherwise inherit the parent's flag.
    correlatable = self.can_be_correlated or scope_type in (
        ScopeType.SUBQUERY,
        ScopeType.UDTF,
    )

    return Scope(
        expression=inner_expression,
        sources=child_sources,
        parent=self,
        scope_type=scope_type,
        cte_sources=visible_ctes,
        lateral_sources=child_laterals,
        can_be_correlated=correlatable,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` and drop this scope's cached state.

    Use this rather than `exp.Expression.replace` so that the `Scope`
    does not keep serving data collected before the change.

    Args:
        old (exp.Expression): node to be replaced
        new (exp.Expression): node to put in its place
    """
    old.replace(new)

    # Anything collected earlier may still refer to `old`.
    self.clear_cache()
Replace old with new.
This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Tables found in this scope (collected on first access).

    Returns:
        list[exp.Table]: tables
    """
    self._ensure_collected()
    collected_tables = self._tables
    return collected_tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    CTEs found in this scope (collected on first access).

    Returns:
        list[exp.CTE]: ctes
    """
    self._ensure_collected()
    collected_ctes = self._ctes
    return collected_ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables found in this scope (collected on first access).

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    Returns:
        list[exp.Subquery]: derived tables
    """
    self._ensure_collected()
    collected_derived = self._derived_tables
    return collected_derived
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Functions" found in this scope (collected on first access).

    Returns:
        list[exp.UDTF]: UDTFs
    """
    self._ensure_collected()
    collected_udtfs = self._udtfs
    return collected_udtfs
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries found in this scope (collected on first access).

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    Returns:
        list[exp.Select | exp.SetOperation]: subqueries
    """
    self._ensure_collected()
    collected_subqueries = self._subqueries
    return collected_subqueries
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.SetOperation]: subqueries
@property
def stars(self) -> t.List[exp.Column | exp.Dot]:
    """
    Star expressions (columns or dots) found in this scope (collected on first access).
    """
    self._ensure_collected()
    collected_stars = self._stars
    return collected_stars
List of star expressions (columns or dots) in this scope.
@property
def columns(self):
    """
    Columns that belong to this scope.

    The candidates are the columns collected in this scope followed by the
    external columns of its subquery scopes, UDTF scopes and those derived
    table scopes that can be correlated. A candidate is kept when any of the
    following holds for its nearest Select/Qualify/Order/Having/Hint/Table/
    Star/Distinct ancestor:

    - there is no such ancestor, the column is table-qualified, or the
      ancestor is a Select;
    - the ancestor is a Table whose `this` is not a function;
    - the ancestor is an Order or Distinct whose parent is a Window or
      WithinGroup, or is not a Select, or the column name is not one of the
      scope expression's named selects;
    - the ancestor is a Star and the column is not its "except" argument.

    The result is cached until `clear_cache` is called.

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is not None:
        return self._columns

    self._ensure_collected()

    candidates = list(self._raw_columns)
    correlated_children = itertools.chain(
        self.subquery_scopes,
        self.udtf_scopes,
        (child for child in self.derived_table_scopes if child.can_be_correlated),
    )
    for child in correlated_children:
        candidates.extend(child.external_columns)

    alias_names = set(self.expression.named_selects)

    kept = self._columns = []
    for column in candidates:
        ancestor = column.find_ancestor(
            exp.Select,
            exp.Qualify,
            exp.Order,
            exp.Having,
            exp.Hint,
            exp.Table,
            exp.Star,
            exp.Distinct,
        )

        if not ancestor or column.table or isinstance(ancestor, exp.Select):
            keep = True
        elif isinstance(ancestor, exp.Table):
            keep = not isinstance(ancestor.this, exp.Func)
        elif isinstance(ancestor, (exp.Order, exp.Distinct)):
            clause_owner = ancestor.parent
            keep = (
                isinstance(clause_owner, (exp.Window, exp.WithinGroup))
                or not isinstance(clause_owner, exp.Select)
                or column.name not in alias_names
            )
        elif isinstance(ancestor, exp.Star):
            keep = column.arg_key != "except"
        else:
            # Remaining ancestors (Qualify, Having, Hint) drop unqualified columns.
            keep = False

        if keep:
            kept.append(column)

    return kept
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Sources that this scope actually selects from, keyed by name.

    Every table in a schema can be selected at any point, but a table only
    counts as a selected source once it appears in a FROM or JOIN clause.
    References whose name is not in `self.sources` are left out, and so are
    the tables recorded as SEMI/ANTI join tables. The result is cached.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

    Raises:
        OptimizeError: if a name that was already selected shows up again.
    """
    if self._selected_sources is not None:
        return self._selected_sources

    selected = {}
    for name, node in self.references:
        # The RHS table of SEMI/ANTI joins shouldn't be collected as a
        # selected source
        if name in self._semi_anti_join_tables:
            continue

        if name in selected:
            raise OptimizeError(f"Alias already used: {name}")

        if name in self.sources:
            selected[name] = (node, self.sources[name])

    self._selected_sources = selected
    return selected
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    (name, node) pairs for everything this scope refers to as a source.

    Tables come first, under their alias or name, followed by derived tables
    and UDTFs under their source alias. Derived tables and UDTFs are unnested
    unless they have pivots. The result is cached.
    """
    if self._references is not None:
        return self._references

    refs = self._references = []

    for table in self.tables:
        refs.append((table.alias_or_name, table))

    for source in itertools.chain(self.derived_tables, self.udtfs):
        alias = _get_source_alias(source)
        # Keep the wrapper when it has pivots; otherwise use the unnested node.
        node = source if source.args.get("pivots") else source.unnest()
        refs.append((alias, node))

    return refs
@property
def external_columns(self):
    """
    Columns that look like references to sources of an outer scope.

    For a set operation this is the left scope's external columns followed by
    the right scope's. Otherwise it is every column of this scope whose table
    is neither a selected source nor a SEMI/ANTI join table. The result is cached.

    Returns:
        list[exp.Column]: Column instances that don't reference
            sources in the current scope.
    """
    if self._external_columns is not None:
        return self._external_columns

    if isinstance(self.expression, exp.SetOperation):
        left, right = self.union_scopes
        outer_refs = left.external_columns + right.external_columns
    else:
        outer_refs = []
        for column in self.columns:
            if column.table in self.selected_sources:
                continue
            if column.table in self.semi_or_anti_join_tables:
                continue
            outer_refs.append(column)

    self._external_columns = outer_refs
    return outer_refs
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns of this scope that have no table qualifier.

    Returns:
        list[exp.Column]: Unqualified columns
    """
    unqualified = []
    for column in self.columns:
        if column.table:
            continue
        unqualified.append(column)
    return unqualified
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Hints that exist in the scope that reference tables.

    Collection is triggered on access, like the other collected properties,
    so the result does not depend on whether another property was read first.

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    self._ensure_collected()

    # Kept as a fallback in case collection leaves the attribute unset.
    if self._join_hints is None:
        return []
    return self._join_hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Select the columns of this scope whose table equals `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances that reference `source_name`
    """
    matching = []
    for column in self.columns:
        if column.table == source_name:
            matching.append(column)
    return matching
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference source_name
@property
def is_subquery(self):
    """Whether this scope's type is `ScopeType.SUBQUERY`."""
    kind = self.scope_type
    return kind == ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """Whether this scope's type is `ScopeType.DERIVED_TABLE`."""
    kind = self.scope_type
    return kind == ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """Whether this scope's type is `ScopeType.UNION`."""
    kind = self.scope_type
    return kind == ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """Whether this scope's type is `ScopeType.CTE` (a common table expression)."""
    kind = self.scope_type
    return kind == ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """Whether this scope's type is `ScopeType.ROOT`."""
    kind = self.scope_type
    return kind == ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """Whether this scope's type is `ScopeType.UDTF` (User Defined Table Function)."""
    kind = self.scope_type
    return kind == ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    The source registered under `old_name` is moved to `new_name`. When
    `old_name` is falsy the empty-string key is looked up instead, and when
    nothing is registered under the old name an empty list is stored under
    `new_name`.

    Cached state is dropped afterwards, as `add_source` and `remove_source`
    do, so that properties derived from `self.sources` (for example
    `selected_sources`) are rebuilt with the new name.

    Args:
        old_name: Current name of the source.
        new_name: Name to register the source under.
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """
    Register `source` under `name` and drop this scope's cached state.

    An existing entry with the same name is overwritten.
    """
    self.sources[name] = source

    # Cached data derived from the sources may be stale now.
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """
    Unregister the source called `name` and drop this scope's cached state.

    A name that is not registered is ignored; the cache is cleared either way.
    """
    self.sources.pop(name, None)

    # Cached data derived from the sources may be stale now.
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree below (and including) this scope.

    Scopes are first gathered with an explicit stack, pushing each scope's
    CTE, union, table and subquery child scopes in that order, and the
    gathered list is then emitted back to front, so this scope comes last.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    pending = [self]
    visited = []

    while pending:
        current = pending.pop()
        visited.append(current)
        pending += [
            *current.cte_scopes,
            *current.union_scopes,
            *current.table_scopes,
            *current.subquery_scopes,
        ]

    for scope in reversed(visited):
        yield scope
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Tally how often each source in this scope tree is referenced.

    Every selected source of every traversed scope counts once. Sources of
    SEMI/ANTI join tables are counted as well when they are registered in
    the scope's sources.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(lambda: 0)

    for scope in self.traverse():
        referenced = [source for _, source in scope.selected_sources.values()]

        # semi/anti join sources are not actually selected but we still need to
        # increment their ref count to avoid them being optimized away
        referenced.extend(
            scope.sources[name]
            for name in scope._semi_anti_join_tables
            if name in scope.sources
        )

        for source in referenced:
            counts[id(source)] += 1

    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Break an expression down into its "scopes".

    A "Scope" captures the context of a Select statement. Optimizations often
    need more than the expression tree alone, for instance the source names
    that are visible inside a subquery. The scopes are returned as a list
    rather than a generator, because a partially consumed generator could
    leave scope properties incomplete, which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty when `expression` is not
        one of the `TRAVERSABLES` types.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root_scope = Scope(expression)
    return list(_traverse_scope(root_scope))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") >>> scopes = traverse_scope(expression) >>> scopes[0].expression.sql(), list(scopes[0].sources) ('SELECT a FROM x', ['x']) >>> scopes[1].expression.sql(), list(scopes[1].sources) ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build the scope tree for an expression.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, taken as the last scope produced by `traverse_scope`.
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Generate the nodes of the syntax tree under `expression`, without
    descending into nodes that start child scopes.

    A node that starts a child scope is itself yielded, but its subtree is
    skipped. For such a node that is a Subquery or UDTF, its "joins",
    "laterals" and "pivots" args are still walked, since they are not part
    of the inner scope.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune (node -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: each visited node
    """
    # State shared with the prune callback handed to `expression.walk`.
    # While it is True, the subtree of the node just yielded is excluded.
    skip_subtree = False

    def _should_prune(candidate):
        return skip_subtree or (prune and prune(candidate))

    for node in expression.walk(bfs=bfs, prune=_should_prune):
        skip_subtree = False

        yield node

        if node is expression:
            continue

        starts_child_scope = (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        )
        if not starts_child_scope:
            continue

        skip_subtree = True

        if isinstance(node, (exp.Subquery, exp.UDTF)):
            # The following args are not actually in the inner scope, so we should visit them
            for key in ("joins", "laterals", "pivots"):
                for arg in node.args.get(key) or []:
                    yield from walk_in_scope(arg, bfs=bfs)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression):
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune (node -> bool): callable that receives a node and returns True if the generator should stop traversing this branch of the tree.
Yields:
exp.Expression: each visited node
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Generate the nodes of this scope that are instances of at least one of
    the given expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    wanted_types = tuple(ensure_collection(expression_types))

    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted_types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Look up the first node in this scope that matches at least one of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.