Edit on GitHub

sqlglot.optimizer.scope

   1from __future__ import annotations
   2
   3import itertools
   4import logging
   5import typing as t
   6from collections import defaultdict
   7from enum import Enum, auto
   8
   9from sqlglot import exp
  10from sqlglot.errors import OptimizeError
  11from sqlglot.helper import find_new_name, mypyc_attr, seq_get
  12from builtins import type as Type
  13
# Logger named "sqlglot"; used below to warn when an expression cannot be traversed.
logger = logging.getLogger("sqlglot")
  15
  16if t.TYPE_CHECKING:
  17    from sqlglot._typing import E
  18    from collections.abc import Iterator
  19
# Expression types that `traverse_scope` accepts as the root of a scope tree;
# anything else makes it return an empty list.
TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)
  21
  22
class ScopeType(Enum):
    """Kind of a `Scope`, i.e. how the scope was reached from its parent."""

    # Default type of a Scope that is constructed directly (see Scope.__init__).
    ROOT = auto()
    # Assigned by _traverse_subqueries to the scopes it branches off.
    SUBQUERY = auto()
    # Assigned by _traverse_tables to subqueries that pass _is_derived_table.
    DERIVED_TABLE = auto()
    # Assigned by _traverse_ctes to the scope of each CTE body.
    CTE = auto()
    # Assigned by _traverse_union to the operands of a set operation.
    UNION = auto()
    # Assigned by _traverse_tables to exp.UDTF expressions.
    UDTF = auto()
  30
  31
@mypyc_attr(native_class=True)
class Scope:
    """
    Selection scope.

    Attributes:
        expression: Root expression of this scope
        sources: Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources: Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources: Sources from CTEs
        outer_columns: If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent: Parent scope
        scope_type: Type of this scope, relative to its parent
        subquery_scopes: List of all child scopes for subqueries
        cte_scopes: List of all child scopes for CTEs
        derived_table_scopes: List of all child scopes for derived_tables
        udtf_scopes: List of all child scopes for user defined tabular functions
        table_scopes: derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes: If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
        can_be_correlated: Flag consulted by `columns` and `is_correlated_subquery`.
            `branch` sets it on a child when the parent already has it set or when the
            child is a SUBQUERY or UDTF scope.
    """

    # True once `_collect` has filled the node lists below.
    _collected: bool
    # Node lists populated by `_collect` (one walk over the scope's expression).
    _raw_columns: list[exp.Column]
    _table_columns: list[exp.TableColumn]
    _stars: list[exp.Column | exp.Dot]
    _derived_tables: list[exp.Subquery]
    _udtfs: list[exp.UDTF]
    _tables: list[exp.Table]
    _ctes: list[exp.CTE]
    _subqueries: list[exp.Select | exp.SetOperation]
    _join_hints: list[exp.JoinHint]
    # `alias_or_name` of every table whose parent is a SEMI/ANTI join.
    _semi_anti_join_tables: set[str]
    # `id()` of every exp.Column node seen by `_collect`.
    _column_index: set[int]
    # Lazily computed caches; None means "not computed yet" (reset by `clear_cache`).
    _selected_sources: dict[str, tuple[exp.Selectable, exp.Table | Scope]] | None
    _columns: list[exp.Column] | None
    _external_columns: list[exp.Column] | None
    _local_columns: list[exp.Column] | None
    _pivots: list[exp.Pivot] | None
    _references: list[tuple[str, exp.Selectable]] | None

    def __init__(
        self,
        expression: exp.Expr,
        sources: dict[str, exp.Table | Scope] | None = None,
        outer_columns: list[str] | None = None,
        parent: Scope | None = None,
        scope_type: ScopeType = ScopeType.ROOT,
        lateral_sources: dict[str, exp.Table | Scope] | None = None,
        cte_sources: dict[str, exp.Table | Scope] | None = None,
        can_be_correlated: bool | None = None,
    ) -> None:
        """
        Create a scope for `expression`.

        See the class docstring for the meaning of the arguments. Non-empty dicts passed
        as `sources`, `lateral_sources` or `cte_sources` are stored as-is (not copied),
        and `sources` is updated in place with the lateral and CTE sources.
        """
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # Lateral and CTE sources are also reachable through the general `sources` mapping.
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes: list[Scope] = []
        self.derived_table_scopes: list[Scope] = []
        self.table_scopes: list[Scope] = []
        self.cte_scopes: list[Scope] = []
        self.union_scopes: list[Scope] = []
        self.udtf_scopes: list[Scope] = []
        self.can_be_correlated = can_be_correlated
        # Initializes every private cache attribute declared on the class.
        self.clear_cache()

    def clear_cache(self) -> None:
        """
        Drop everything derived from the expression tree.

        The collected node lists are emptied and the lazily computed properties are
        reset to None, so the next property access recomputes them. `sources` and the
        child scope lists are not touched.
        """
        self._collected = False
        self._raw_columns = []
        self._table_columns = []
        self._stars = []
        self._derived_tables = []
        self._udtfs = []
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._join_hints = []
        self._semi_anti_join_tables = set()
        self._column_index = set()
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._local_columns = None
        self._pivots = None
        self._references = None

    def branch(
        self,
        expression: exp.Expr,
        scope_type: ScopeType,
        sources: dict[str, exp.Table | Scope] | None = None,
        cte_sources: dict[str, exp.Table | Scope] | None = None,
        lateral_sources: dict[str, exp.Table | Scope] | None = None,
        outer_columns: list[str] | None = None,
    ) -> Scope:
        """
        Branch from the current scope to a new, inner scope

        The child is built on `expression.unnest()` with this scope as its parent. It
        receives copies of `sources` and `lateral_sources`, and a new `cte_sources` dict
        holding this scope's CTE sources overlaid with the given `cte_sources`.
        """
        return Scope(
            expression=expression.unnest(),
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            # Inherited from the parent, or switched on for subquery and UDTF scopes.
            can_be_correlated=self.can_be_correlated
            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
            outer_columns=outer_columns,
        )

    def _collect(self) -> None:
        """
        Walk this scope's expression once and sort the visited nodes into the
        per-kind lists that back the public properties (`tables`, `ctes`, ...).
        """
        # Start from empty containers so a repeated collection never duplicates entries.
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._table_columns = []
        self._stars = []
        self._join_hints = []
        self._semi_anti_join_tables = set()
        self._column_index = set()

        for node in self.walk():
            # Skip the root expression itself; only the other walked nodes are classified.
            if node is self.expression:
                continue

            # Order matters: a node is recorded in the first branch it matches.
            if isinstance(node, exp.Dot) and node.is_star:
                self._stars.append(node)
            # Exact type comparison, so subclasses of exp.Column do not match here.
            elif type(node) is exp.Column:
                self._column_index.add(id(node))

                # `tbl.*` style columns are tracked as stars, not as regular columns.
                if isinstance(node.this, exp.Star):
                    self._stars.append(node)
                else:
                    self._raw_columns.append(node)
            # Tables that sit directly under a join hint are not recorded as tables.
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                parent = node.parent
                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
                    self._semi_anti_join_tables.add(node.alias_or_name)

                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif _is_derived_table(node) and _is_from_or_join(node):
                self._derived_tables.append(t.cast(exp.Subquery, node))
            # Unwrapped queries outside of FROM/JOIN are this scope's subqueries.
            elif isinstance(node, exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node):
                self._subqueries.append(node)
            elif isinstance(node, exp.TableColumn):
                self._table_columns.append(node)

        self._collected = True

    def _ensure_collected(self) -> None:
        """Run `_collect` unless it already ran since the last `clear_cache`."""
        if not self._collected:
            self._collect()

    def walk(self, prune: t.Callable[[exp.Expr], bool] | None = None) -> Iterator[exp.Expr]:
        """Iterate over this scope's expression by delegating to `walk_in_scope`."""
        return walk_in_scope(self.expression, prune=prune)

    def find(self, *expression_types: Type[E]) -> E | None:
        """Delegate to `find_in_scope` on this scope's expression."""
        return find_in_scope(self.expression, *expression_types)

    def find_all(self, *expression_types: Type[E]) -> Iterator[E]:
        """Delegate to `find_all_in_scope` on this scope's expression."""
        return find_all_in_scope(self.expression, *expression_types)

    def replace(self, old: exp.Expr, new: exp.Expr) -> None:
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expr.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expr): old node
            new (exp.Expr): new node
        """
        old.replace(new)
        # The tree changed, so everything derived from it has to be recomputed.
        self.clear_cache()

    @property
    def tables(self) -> list[exp.Table]:
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self) -> list[exp.CTE]:
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self) -> list[exp.Subquery]:
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self) -> list[exp.UDTF]:
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self) -> list[exp.Select | exp.SetOperation]:
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.SetOperation]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def stars(self) -> list[exp.Column | exp.Dot]:
        """
        List of star expressions (columns or dots) in this scope.
        """
        self._ensure_collected()
        return self._stars

    @property
    def column_index(self) -> set[int]:
        """
        Set of column object IDs that belong to this scope's expression.
        """
        self._ensure_collected()
        return self._column_index

    @property
    def columns(self) -> list[exp.Column]:
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            # External columns of subquery and UDTF child scopes, and of those
            # derived-table child scopes that are flagged `can_be_correlated`.
            external_columns = [
                column
                for scope in itertools.chain(
                    self.subquery_scopes,
                    self.udtf_scopes,
                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
                )
                for column in scope.external_columns
            ]

            expr = self.expression
            named_selects = set(expr.named_selects) if isinstance(expr, exp.Query) else set()

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select,
                    exp.Qualify,
                    exp.Order,
                    exp.Having,
                    exp.Hint,
                    exp.Table,
                    exp.Star,
                    exp.Distinct,
                )
                # A column is kept when any of the conditions below holds. In particular, an
                # unqualified column under ORDER/DISTINCT directly below a SELECT is dropped
                # when its name is one of the named selects.
                # NOTE(review): presumably such a name refers to a SELECT alias rather than
                # to a source column - confirm.
                if (
                    not ancestor
                    or column.text("table")
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, (exp.Order, exp.Distinct))
                        and (
                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
                            or not isinstance(ancestor.parent, exp.Select)
                            or column.name not in named_selects
                        )
                    )
                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def table_columns(self) -> list[exp.TableColumn]:
        """List of `exp.TableColumn` nodes collected in this scope."""
        self._ensure_collected()
        return self._table_columns

    @property
    def selected_sources(self) -> dict[str, tuple[exp.Selectable, exp.Table | Scope]]:
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: If a name that was already added to the result is referenced again.
        """
        if self._selected_sources is None:
            result: dict[str, tuple[exp.Selectable, exp.Table | Scope]] = {}

            # Evaluating `self.references` collects the scope, so the semi/anti join
            # table names read below are populated by the time the loop body runs.
            for name, node in self.references:
                if name in self._semi_anti_join_tables:
                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
                    # selected source
                    continue

                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                # References whose name is not a known source are silently skipped.
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> list[tuple[str, exp.Selectable]]:
        """
        List of `(name, node)` pairs for everything this scope refers to as a source.

        Tables come first, named by `alias_or_name`, followed by derived tables and
        UDTFs, named by `_get_source_alias`. For the latter, the node is the expression
        itself when it has pivots and its `unnest()` otherwise.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for _expr in itertools.chain(self.derived_tables, self.udtfs):
                # TODO (mypyc): rebind to exp.Expr to avoid DerivedTable trait vtable dispatch
                expression: exp.Expr = _expr
                self._references.append(
                    (
                        _get_source_alias(expression),
                        (
                            expression if expression.args.get("pivots") else expression.unnest()
                        ).assert_is(exp.Selectable),
                    )
                )

        return self._references

    @property
    def external_columns(self) -> list[exp.Column]:
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.SetOperation):
                # A set operation has no columns of its own here: combine both operands.
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                local_source_names = {name for name, _ in self.references}
                # External = table qualifier is neither a local reference nor a semi/anti
                # join table. An unqualified column has "" as its qualifier, so it is only
                # treated as local when "" itself is one of those names.
                self._external_columns = [
                    c
                    for c in self.columns
                    if c.text("table") not in local_source_names
                    and c.text("table") not in self.semi_or_anti_join_tables
                ]

        return self._external_columns

    @property
    def local_columns(self) -> list[exp.Column]:
        """
        Columns in this scope that are not external.

        Returns:
            list[exp.Column]: Column instances that reference sources in the current scope.
        """
        if self._local_columns is None:
            external_columns = set(self.external_columns)
            self._local_columns = [c for c in self.columns if c not in external_columns]

        return self._local_columns

    @property
    def unqualified_columns(self) -> list[exp.Column]:
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.text("table")]

    @property
    def join_hints(self) -> list[exp.JoinHint]:
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self) -> list[exp.Pivot]:
        """List of the entries found under the "pivots" arg of every node in `references`."""
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    @property
    def semi_or_anti_join_tables(self) -> set[str]:
        """Names (`alias_or_name`) of the tables whose parent is a SEMI or ANTI join."""
        self._ensure_collected()
        return self._semi_anti_join_tables

    def source_columns(self, source_name: str) -> list[exp.Column]:
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.text("table") == source_name]

    @property
    def is_subquery(self) -> bool:
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self) -> bool:
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self) -> bool:
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self) -> bool:
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self) -> bool:
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self) -> bool:
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self) -> bool:
        """Determine if this scope is a correlated subquery"""
        return bool(self.can_be_correlated and self.external_columns)

    def rename_source(self, old_name: str | None, new_name: str) -> None:
        """Rename a source in this scope"""
        # A missing name is looked up as "".
        old_name = old_name or ""
        if old_name in self.sources:
            self.sources[new_name] = self.sources.pop(old_name)
        # NOTE(review): unlike add_source/remove_source, this does not call
        # clear_cache() - confirm that this is intentional.

    def add_source(self, name: str, source: exp.Table | Scope) -> None:
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name: str) -> None:
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self) -> str:
        """Return `Scope<...>` wrapping the SQL of this scope's expression."""
        return f"Scope<{self.expression.sql()}>"

    def traverse(self) -> Iterator[Scope]:
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        stack: list[Scope] = [self]
        result: list[Scope] = []
        # Every scope is appended to `result` before any of its children are visited,
        # so iterating `result` in reverse yields all descendants before their ancestor.
        while stack:
            scope = stack.pop()
            result.append(scope)
            stack.extend(
                itertools.chain(
                    scope.cte_scopes,
                    scope.union_scopes,
                    scope.table_scopes,
                    scope.subquery_scopes,
                )
            )

        yield from reversed(result)

    def ref_count(self) -> dict[int, int]:
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count: dict[int, int] = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

            for name in scope._semi_anti_join_tables:
                # semi/anti join sources are not actually selected but we still need to
                # increment their ref count to avoid them being optimized away
                if name in scope.sources:
                    scope_ref_count[id(scope.sources[name])] += 1

        return scope_ref_count
 596
 597
 598def traverse_scope(expression: exp.Expr) -> list[Scope]:
 599    """
 600    Traverse an expression by its "scopes".
 601
 602    "Scope" represents the current context of a Select statement.
 603
 604    This is helpful for optimizing queries, where we need more information than
 605    the expression tree itself. For example, we might care about the source
 606    names within a subquery. Returns a list because a generator could result in
 607    incomplete properties which is confusing.
 608
 609    Examples:
 610        >>> import sqlglot
 611        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
 612        >>> scopes = traverse_scope(expression)
 613        >>> scopes[0].expression.sql(), list(scopes[0].sources)
 614        ('SELECT a FROM x', ['x'])
 615        >>> scopes[1].expression.sql(), list(scopes[1].sources)
 616        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
 617
 618    Args:
 619        expression: Expr to traverse
 620
 621    Returns:
 622        A list of the created scope instances
 623    """
 624    if isinstance(expression, TRAVERSABLES):
 625        return list(_traverse_scope(Scope(expression)))
 626    return []
 627
 628
 629def build_scope(expression: exp.Expr) -> Scope | None:
 630    """
 631    Build a scope tree.
 632
 633    Args:
 634        expression: Expr to build the scope tree for.
 635
 636    Returns:
 637        The root scope
 638    """
 639    return seq_get(traverse_scope(expression), -1)
 640
 641
 642def _traverse_scope(scope: Scope) -> Iterator[Scope]:
 643    expression = scope.expression
 644
 645    if isinstance(expression, exp.Select):
 646        yield from _traverse_select(scope)
 647    elif isinstance(expression, exp.SetOperation):
 648        yield from _traverse_ctes(scope)
 649        yield from _traverse_union(scope)
 650        return
 651    elif isinstance(expression, exp.Subquery):
 652        if scope.is_root:
 653            yield from _traverse_select(scope)
 654        else:
 655            yield from _traverse_subqueries(scope)
 656    elif isinstance(expression, exp.Table):
 657        yield from _traverse_tables(scope)
 658    elif isinstance(expression, exp.UDTF):
 659        yield from _traverse_udtfs(scope)
 660    elif isinstance(expression, exp.DDL):
 661        # TODO (mypyc): change to ddl_expression = expression.expression
 662        ddl_expression = expression.args.get("expression")
 663        if isinstance(ddl_expression, exp.Query):
 664            yield from _traverse_ctes(scope)
 665            yield from _traverse_scope(Scope(ddl_expression, cte_sources=scope.cte_sources))
 666        return
 667    elif isinstance(expression, exp.DML):
 668        yield from _traverse_ctes(scope)
 669        for query in find_all_in_scope(expression, exp.Query):
 670            # This check ensures we don't yield the CTE/nested queries twice
 671            if not isinstance(query.parent, (exp.CTE, exp.Subquery)):
 672                yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
 673        return
 674    else:
 675        logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
 676        return
 677
 678    yield scope
 679
 680
 681def _traverse_select(scope: Scope) -> Iterator[Scope]:
 682    yield from _traverse_ctes(scope)
 683    yield from _traverse_tables(scope)
 684    yield from _traverse_subqueries(scope)
 685
 686
def _traverse_union(scope: Scope) -> Iterator[Scope]:
    """
    Yield the scopes of a set operation's operands, iteratively instead of recursively.

    Two stacks are maintained: `expression_stack` holds the operands that still have
    to be processed, and `union_scope_stack` holds the scopes of the set operations
    that are currently open. A set operation's scope is yielded once its second
    operand has been processed, after its `union_scopes` have been assigned.

    The expression of `scope` must be an `exp.SetOperation` (asserted below).
    """
    # Scope of the most recently completed operand (or completed set operation).
    prev_scope: Scope | None = None
    union_scope_stack: list[Scope] = [scope]

    set_op = scope.expression
    assert isinstance(set_op, exp.SetOperation)
    # The right operand is pushed first so that the left one is popped first.
    expression_stack: list[exp.Expr] = [set_op.right, set_op.left]

    while expression_stack:
        expression = expression_stack.pop()
        # The innermost open set operation is the parent of the operand at hand.
        union_scope = union_scope_stack[-1]

        new_scope = union_scope.branch(
            expression,
            outer_columns=union_scope.outer_columns,
            scope_type=ScopeType.UNION,
        )

        if isinstance(expression, exp.SetOperation):
            # Nested set operation: open it and queue its own operands (left first)
            # instead of recursing.
            yield from _traverse_ctes(new_scope)

            union_scope_stack.append(new_scope)
            expression_stack.extend([expression.right, expression.left])
            continue

        # NOTE: this loop variable shadows the `scope` parameter; after the loop it
        # refers to the last scope that was yielded for this operand.
        for scope in _traverse_scope(new_scope):
            yield scope

        if prev_scope:
            # A previous operand exists, so this one completes the innermost open set
            # operation: close it, record both operand scopes and yield its scope.
            union_scope_stack.pop()
            union_scope.union_scopes = [prev_scope, scope]
            prev_scope = union_scope

            yield union_scope
        else:
            # First completed operand; remember it until its sibling is done.
            prev_scope = scope
 723
 724
 725def _traverse_ctes(scope: Scope) -> Iterator[Scope]:
 726    sources: dict[str, exp.Table | Scope] = {}
 727
 728    for cte in scope.ctes:
 729        cte_name = cte.alias
 730
 731        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
 732        # thus the recursive scope is the first section of the union.
 733        with_ = scope.expression.args.get("with_")
 734        if with_ and with_.recursive:
 735            union = cte.this
 736
 737            if isinstance(union, exp.SetOperation):
 738                sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)
 739
 740        child_scope: Scope | None = None
 741
 742        for child_scope in _traverse_scope(
 743            scope.branch(
 744                cte.this,
 745                cte_sources=sources,
 746                outer_columns=cte.alias_column_names,
 747                scope_type=ScopeType.CTE,
 748            )
 749        ):
 750            yield child_scope
 751
 752        # append the final child_scope yielded
 753        if child_scope:
 754            sources[cte_name] = child_scope
 755            scope.cte_scopes.append(child_scope)
 756
 757    scope.sources.update(sources)
 758    scope.cte_sources.update(sources)
 759
 760
 761def _is_derived_table(expression: exp.Expr) -> bool:
 762    """
 763    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
 764    as it doesn't introduce a new scope. If an alias is present, it shadows all names
 765    under the Subquery, so that's one exception to this rule.
 766    """
 767    return isinstance(expression, exp.Subquery) and bool(
 768        expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
 769    )
 770
 771
 772def _is_from_or_join(expression: exp.Expr) -> bool:
 773    """
 774    Determine if `expression` is the FROM or JOIN clause of a SELECT statement.
 775    """
 776    parent = expression.parent
 777
 778    # Subqueries can be arbitrarily nested
 779    while type(parent) is exp.Subquery:
 780        parent = parent.parent
 781
 782    return type(parent) in (exp.From, exp.Join)
 783
 784
def _traverse_tables(scope: Scope) -> Iterator[Scope]:
    """
    Resolve the FROM, JOIN and LATERAL entries of `scope` into sources.

    Plain tables are recorded as sources directly. Derived tables and UDTFs are
    traversed in child scopes, which are yielded and then recorded as sources, in
    `scope.table_scopes`, and in `scope.derived_table_scopes` or `scope.udtf_scopes`.
    The collected sources are merged into `scope.sources` at the end.
    """
    sources: dict[str, exp.Table | Scope] = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions: list[exp.Expr] = []
    from_ = scope.expression.args.get("from_")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    # When the scope's expression is itself a Table, it is processed like any other entry.
    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # NOTE: `expressions` is used as a work list and grows while it is being iterated;
    # items appended inside the loop body are visited by later iterations.
    for expression in expressions:
        # An exp.Final wrapper is replaced by the expression it wraps.
        if isinstance(expression, exp.Final):
            expression = expression.this
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                # The name is already taken in this scope: store the table under a new
                # name derived from its table name.
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        # Anything that is neither a Table nor a DerivedTable contributes no source.
        if not isinstance(expression, exp.DerivedTable):
            continue

        # TODO (mypyc): rebind to exp.Expr to avoid DerivedTable/UDTF trait vtable dispatch
        node: exp.Expr = expression

        if isinstance(expression, exp.UDTF):
            # `branch` copies this dict, so the UDTF scope gets the sources that were
            # collected before it as its lateral sources.
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in node.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(node.this)
            expressions.extend(join.this for join in node.args.get("joins") or [])
            continue

        child_scope: Scope | None = None

        for child_scope in _traverse_scope(
            scope.branch(
                node,
                lateral_sources=lateral_sources,
                outer_columns=node.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            # The assignment runs for every yielded scope, so the last one ends up stored.
            sources[_get_source_alias(node)] = child_scope

        # append the final child_scope yielded
        if child_scope:
            scopes.append(child_scope)
            scope.table_scopes.append(child_scope)

    scope.sources.update(sources)
 873
 874
 875def _traverse_subqueries(scope: Scope) -> Iterator[Scope]:
 876    for subquery in scope.subqueries:
 877        top: Scope | None = None
 878        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
 879            yield child_scope
 880            top = child_scope
 881        if top is not None:
 882            scope.subquery_scopes.append(top)
 883
 884
 885def _traverse_udtfs(scope: Scope) -> Iterator[Scope]:
 886    if isinstance(scope.expression, exp.Unnest):
 887        udtf_expressions = scope.expression.expressions
 888    elif isinstance(scope.expression, exp.Lateral):
 889        udtf_expressions = [scope.expression.this]
 890    else:
 891        udtf_expressions = []
 892
 893    sources: dict[str, exp.Table | Scope] = {}
 894    for expression in udtf_expressions:
 895        if isinstance(expression, exp.Subquery):
 896            top: Scope | None = None
 897            for child_scope in _traverse_scope(
 898                scope.branch(
 899                    expression,
 900                    scope_type=ScopeType.SUBQUERY,
 901                    outer_columns=expression.alias_column_names,
 902                )
 903            ):
 904                yield child_scope
 905                top = child_scope
 906                sources[_get_source_alias(expression)] = child_scope
 907
 908            if top is not None:
 909                scope.subquery_scopes.append(top)
 910
 911    scope.sources.update(sources)
 912
 913
 914def walk_in_scope(
 915    expression: exp.Expr,
 916    prune: t.Callable[[exp.Expr], bool] | None = None,
 917) -> Iterator[exp.Expr]:
 918    """
 919    Returns a generator object which visits all nodes in the syntrax tree, stopping at
 920    nodes that start child scopes. This does a custom DFS traversal rather than using
 921    expression.walk() because the nested generators aren't optimized in mypyc.
 922
 923    Args:
 924        expression:
 925        prune: callable that returns True if
 926            the generator should stop traversing this branch of the tree.
 927
 928    Yields:
 929        exp.Expr: each node in scope
 930    """
 931    stack: list[exp.Expr] = [expression]
 932
 933    while stack:
 934        node = stack.pop()
 935
 936        yield node
 937
 938        if node is not expression and (
 939            isinstance(node, exp.CTE)
 940            or (isinstance(node.parent, (exp.From, exp.Join)) and _is_derived_table(node))
 941            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
 942            or isinstance(node, exp.UNWRAPPED_QUERIES)
 943        ):
 944            if isinstance(node, (exp.Subquery, exp.UDTF)):
 945                for key in ("joins", "laterals", "pivots"):
 946                    for arg in node.args.get(key) or []:
 947                        yield from walk_in_scope(arg)
 948            continue
 949
 950        if prune and prune(node):
 951            continue
 952
 953        for vs in reversed(node.args.values()):
 954            if isinstance(vs, list):
 955                for v in reversed(vs):
 956                    if isinstance(v, exp.Expr):
 957                        stack.append(v)
 958            elif isinstance(vs, exp.Expr):
 959                stack.append(vs)
 960
 961
 962def find_all_in_scope(
 963    expression: exp.Expr,
 964    *expression_types: Type[E],
 965) -> Iterator[E]:
 966    """
 967    Returns a generator object which visits all nodes in this scope and only yields those that
 968    match at least one of the specified expression types.
 969
 970    This does NOT traverse into subscopes.
 971
 972    Args:
 973        expression: the expression to search.
 974        expression_types: the expression type(s) to match.
 975
 976    Yields:
 977        The matching nodes.
 978    """
 979    for node in walk_in_scope(expression):
 980        if isinstance(node, expression_types):
 981            yield node
 982
 983
 984def find_in_scope(
 985    expression: exp.Expr,
 986    *expression_types: Type[E],
 987) -> E | None:
 988    """
 989    Returns the first node in this scope which matches at least one of the specified types.
 990
 991    This does NOT traverse into subscopes.
 992
 993    Args:
 994        expression: the expression to search.
 995        expression_types: the expression type(s) to match.
 996
 997    Returns:
 998        The node which matches the criteria or None if no node matching
 999        the criteria was found.
1000    """
1001    return next(find_all_in_scope(expression, *expression_types), None)
1002
1003
1004def _get_source_alias(expression: exp.Expr) -> str:
1005    alias_arg = expression.args.get("alias")
1006    alias_name = expression.alias
1007
1008    if not alias_name and isinstance(alias_arg, exp.TableAlias) and len(alias_arg.columns) == 1:
1009        alias_name = alias_arg.columns[0].name
1010
1011    return alias_name
logger = <Logger sqlglot (WARNING)>
class ScopeType(enum.Enum):
class ScopeType(Enum):
    """Kind of a scope, relative to the scope it was branched from."""

    # Top-level scope; this is the default `scope_type` of a newly built Scope.
    ROOT = auto()
    # Scope branched for a subquery.
    SUBQUERY = auto()
    # Scope branched for a derived table.
    DERIVED_TABLE = auto()
    # Scope of a common table expression.
    CTE = auto()
    # Scope of a union.
    UNION = auto()
    # Scope branched for a UDTF (user defined tabular function).
    UDTF = auto()
30    UDTF = auto()

An enumeration.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
@mypyc_attr(native_class=True)
class Scope:
 33@mypyc_attr(native_class=True)
 34class Scope:
 35    """
 36    Selection scope.
 37
 38    Attributes:
 39        expression: Root expression of this scope
 40        sources: Mapping of source name to either
 41            a Table expression or another Scope instance. For example:
 42                SELECT * FROM x                     {"x": Table(this="x")}
 43                SELECT * FROM x AS y                {"y": Table(this="x")}
 44                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 45        lateral_sources: Sources from laterals
 46            For example:
 47                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 48            The LATERAL VIEW EXPLODE gets x as a source.
 49        cte_sources: Sources from CTES
 50        outer_columns: If this is a derived table or CTE, and the outer query
 51            defines a column list for the alias of this scope, this is that list of columns.
 52            For example:
 53                SELECT * FROM (SELECT ...) AS y(col1, col2)
 54            The inner query would have `["col1", "col2"]` for its `outer_columns`
 55        parent: Parent scope
 56        scope_type: Type of this scope, relative to it's parent
 57        subquery_scopes: List of all child scopes for subqueries
 58        cte_scopes: List of all child scopes for CTEs
 59        derived_table_scopes: List of all child scopes for derived_tables
 60        udtf_scopes: List of all child scopes for user defined tabular functions
 61        table_scopes: derived_table_scopes + udtf_scopes, in the order that they're defined
 62        union_scopes: If this Scope is for a Union expression, this will be
 63            a list of the left and right child scopes.
 64    """
 65
 66    _collected: bool
 67    _raw_columns: list[exp.Column]
 68    _table_columns: list[exp.TableColumn]
 69    _stars: list[exp.Column | exp.Dot]
 70    _derived_tables: list[exp.Subquery]
 71    _udtfs: list[exp.UDTF]
 72    _tables: list[exp.Table]
 73    _ctes: list[exp.CTE]
 74    _subqueries: list[exp.Select | exp.SetOperation]
 75    _join_hints: list[exp.JoinHint]
 76    _semi_anti_join_tables: set[str]
 77    _column_index: set[int]
 78    _selected_sources: dict[str, tuple[exp.Selectable, exp.Table | Scope]] | None
 79    _columns: list[exp.Column] | None
 80    _external_columns: list[exp.Column] | None
 81    _local_columns: list[exp.Column] | None
 82    _pivots: list[exp.Pivot] | None
 83    _references: list[tuple[str, exp.Selectable]] | None
 84
 85    def __init__(
 86        self,
 87        expression: exp.Expr,
 88        sources: dict[str, exp.Table | Scope] | None = None,
 89        outer_columns: list[str] | None = None,
 90        parent: Scope | None = None,
 91        scope_type: ScopeType = ScopeType.ROOT,
 92        lateral_sources: dict[str, exp.Table | Scope] | None = None,
 93        cte_sources: dict[str, exp.Table | Scope] | None = None,
 94        can_be_correlated: bool | None = None,
 95    ) -> None:
 96        self.expression = expression
 97        self.sources = sources or {}
 98        self.lateral_sources = lateral_sources or {}
 99        self.cte_sources = cte_sources or {}
100        self.sources.update(self.lateral_sources)
101        self.sources.update(self.cte_sources)
102        self.outer_columns = outer_columns or []
103        self.parent = parent
104        self.scope_type = scope_type
105        self.subquery_scopes: list[Scope] = []
106        self.derived_table_scopes: list[Scope] = []
107        self.table_scopes: list[Scope] = []
108        self.cte_scopes: list[Scope] = []
109        self.union_scopes: list[Scope] = []
110        self.udtf_scopes: list[Scope] = []
111        self.can_be_correlated = can_be_correlated
112        self.clear_cache()
113
114    def clear_cache(self) -> None:
115        self._collected = False
116        self._raw_columns = []
117        self._table_columns = []
118        self._stars = []
119        self._derived_tables = []
120        self._udtfs = []
121        self._tables = []
122        self._ctes = []
123        self._subqueries = []
124        self._join_hints = []
125        self._semi_anti_join_tables = set()
126        self._column_index = set()
127        self._selected_sources = None
128        self._columns = None
129        self._external_columns = None
130        self._local_columns = None
131        self._pivots = None
132        self._references = None
133
134    def branch(
135        self,
136        expression: exp.Expr,
137        scope_type: ScopeType,
138        sources: dict[str, exp.Table | Scope] | None = None,
139        cte_sources: dict[str, exp.Table | Scope] | None = None,
140        lateral_sources: dict[str, exp.Table | Scope] | None = None,
141        outer_columns: list[str] | None = None,
142    ) -> Scope:
143        """Branch from the current scope to a new, inner scope"""
144        return Scope(
145            expression=expression.unnest(),
146            sources=sources.copy() if sources else None,
147            parent=self,
148            scope_type=scope_type,
149            cte_sources={**self.cte_sources, **(cte_sources or {})},
150            lateral_sources=lateral_sources.copy() if lateral_sources else None,
151            can_be_correlated=self.can_be_correlated
152            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
153            outer_columns=outer_columns,
154        )
155
156    def _collect(self) -> None:
157        self._tables = []
158        self._ctes = []
159        self._subqueries = []
160        self._derived_tables = []
161        self._udtfs = []
162        self._raw_columns = []
163        self._table_columns = []
164        self._stars = []
165        self._join_hints = []
166        self._semi_anti_join_tables = set()
167        self._column_index = set()
168
169        for node in self.walk():
170            if node is self.expression:
171                continue
172
173            if isinstance(node, exp.Dot) and node.is_star:
174                self._stars.append(node)
175            elif type(node) is exp.Column:
176                self._column_index.add(id(node))
177
178                if isinstance(node.this, exp.Star):
179                    self._stars.append(node)
180                else:
181                    self._raw_columns.append(node)
182            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
183                parent = node.parent
184                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
185                    self._semi_anti_join_tables.add(node.alias_or_name)
186
187                self._tables.append(node)
188            elif isinstance(node, exp.JoinHint):
189                self._join_hints.append(node)
190            elif isinstance(node, exp.UDTF):
191                self._udtfs.append(node)
192            elif isinstance(node, exp.CTE):
193                self._ctes.append(node)
194            elif _is_derived_table(node) and _is_from_or_join(node):
195                self._derived_tables.append(t.cast(exp.Subquery, node))
196            elif isinstance(node, exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node):
197                self._subqueries.append(node)
198            elif isinstance(node, exp.TableColumn):
199                self._table_columns.append(node)
200
201        self._collected = True
202
203    def _ensure_collected(self) -> None:
204        if not self._collected:
205            self._collect()
206
207    def walk(self, prune: t.Callable[[exp.Expr], bool] | None = None) -> Iterator[exp.Expr]:
208        return walk_in_scope(self.expression, prune=prune)
209
210    def find(self, *expression_types: Type[E]) -> E | None:
211        return find_in_scope(self.expression, *expression_types)
212
213    def find_all(self, *expression_types: Type[E]) -> Iterator[E]:
214        return find_all_in_scope(self.expression, *expression_types)
215
216    def replace(self, old: exp.Expr, new: exp.Expr) -> None:
217        """
218        Replace `old` with `new`.
219
220        This can be used instead of `exp.Expr.replace` to ensure the `Scope` is kept up-to-date.
221
222        Args:
223            old (exp.Expr): old node
224            new (exp.Expr): new node
225        """
226        old.replace(new)
227        self.clear_cache()
228
229    @property
230    def tables(self) -> list[exp.Table]:
231        """
232        List of tables in this scope.
233
234        Returns:
235            list[exp.Table]: tables
236        """
237        self._ensure_collected()
238        return self._tables
239
240    @property
241    def ctes(self) -> list[exp.CTE]:
242        """
243        List of CTEs in this scope.
244
245        Returns:
246            list[exp.CTE]: ctes
247        """
248        self._ensure_collected()
249        return self._ctes
250
251    @property
252    def derived_tables(self) -> list[exp.Subquery]:
253        """
254        List of derived tables in this scope.
255
256        For example:
257            SELECT * FROM (SELECT ...) <- that's a derived table
258
259        Returns:
260            list[exp.Subquery]: derived tables
261        """
262        self._ensure_collected()
263        return self._derived_tables
264
265    @property
266    def udtfs(self) -> list[exp.UDTF]:
267        """
268        List of "User Defined Tabular Functions" in this scope.
269
270        Returns:
271            list[exp.UDTF]: UDTFs
272        """
273        self._ensure_collected()
274        return self._udtfs
275
276    @property
277    def subqueries(self) -> list[exp.Select | exp.SetOperation]:
278        """
279        List of subqueries in this scope.
280
281        For example:
282            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
283
284        Returns:
285            list[exp.Select | exp.SetOperation]: subqueries
286        """
287        self._ensure_collected()
288        return self._subqueries
289
290    @property
291    def stars(self) -> list[exp.Column | exp.Dot]:
292        """
293        List of star expressions (columns or dots) in this scope.
294        """
295        self._ensure_collected()
296        return self._stars
297
298    @property
299    def column_index(self) -> set[int]:
300        """
301        Set of column object IDs that belong to this scope's expression.
302        """
303        self._ensure_collected()
304        return self._column_index
305
306    @property
307    def columns(self) -> list[exp.Column]:
308        """
309        List of columns in this scope.
310
311        Returns:
312            list[exp.Column]: Column instances in this scope, plus any
313                Columns that reference this scope from correlated subqueries.
314        """
315        if self._columns is None:
316            self._ensure_collected()
317            columns = self._raw_columns
318
319            external_columns = [
320                column
321                for scope in itertools.chain(
322                    self.subquery_scopes,
323                    self.udtf_scopes,
324                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
325                )
326                for column in scope.external_columns
327            ]
328
329            expr = self.expression
330            named_selects = set(expr.named_selects) if isinstance(expr, exp.Query) else set()
331
332            self._columns = []
333            for column in columns + external_columns:
334                ancestor = column.find_ancestor(
335                    exp.Select,
336                    exp.Qualify,
337                    exp.Order,
338                    exp.Having,
339                    exp.Hint,
340                    exp.Table,
341                    exp.Star,
342                    exp.Distinct,
343                )
344                if (
345                    not ancestor
346                    or column.text("table")
347                    or isinstance(ancestor, exp.Select)
348                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
349                    or (
350                        isinstance(ancestor, (exp.Order, exp.Distinct))
351                        and (
352                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
353                            or not isinstance(ancestor.parent, exp.Select)
354                            or column.name not in named_selects
355                        )
356                    )
357                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
358                ):
359                    self._columns.append(column)
360
361        return self._columns
362
363    @property
364    def table_columns(self) -> list[exp.TableColumn]:
365        self._ensure_collected()
366        return self._table_columns
367
368    @property
369    def selected_sources(self) -> dict[str, tuple[exp.Selectable, exp.Table | Scope]]:
370        """
371        Mapping of nodes and sources that are actually selected from in this scope.
372
373        That is, all tables in a schema are selectable at any point. But a
374        table only becomes a selected source if it's included in a FROM or JOIN clause.
375
376        Returns:
377            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
378        """
379        if self._selected_sources is None:
380            result: dict[str, tuple[exp.Selectable, exp.Table | Scope]] = {}
381
382            for name, node in self.references:
383                if name in self._semi_anti_join_tables:
384                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
385                    # selected source
386                    continue
387
388                if name in result:
389                    raise OptimizeError(f"Alias already used: {name}")
390                if name in self.sources:
391                    result[name] = (node, self.sources[name])
392
393            self._selected_sources = result
394        return self._selected_sources
395
396    @property
397    def references(self) -> list[tuple[str, exp.Selectable]]:
398        if self._references is None:
399            self._references = []
400
401            for table in self.tables:
402                self._references.append((table.alias_or_name, table))
403            for _expr in itertools.chain(self.derived_tables, self.udtfs):
404                # TODO (mypyc): rebind to exp.Expr to avoid DerivedTable trait vtable dispatch
405                expression: exp.Expr = _expr
406                self._references.append(
407                    (
408                        _get_source_alias(expression),
409                        (
410                            expression if expression.args.get("pivots") else expression.unnest()
411                        ).assert_is(exp.Selectable),
412                    )
413                )
414
415        return self._references
416
417    @property
418    def external_columns(self) -> list[exp.Column]:
419        """
420        Columns that appear to reference sources in outer scopes.
421
422        Returns:
423            list[exp.Column]: Column instances that don't reference sources in the current scope.
424        """
425        if self._external_columns is None:
426            if isinstance(self.expression, exp.SetOperation):
427                left, right = self.union_scopes
428                self._external_columns = left.external_columns + right.external_columns
429            else:
430                local_source_names = {name for name, _ in self.references}
431                self._external_columns = [
432                    c
433                    for c in self.columns
434                    if c.text("table") not in local_source_names
435                    and c.text("table") not in self.semi_or_anti_join_tables
436                ]
437
438        return self._external_columns
439
440    @property
441    def local_columns(self) -> list[exp.Column]:
442        """
443        Columns in this scope that are not external.
444
445        Returns:
446            list[exp.Column]: Column instances that reference sources in the current scope.
447        """
448        if self._local_columns is None:
449            external_columns = set(self.external_columns)
450            self._local_columns = [c for c in self.columns if c not in external_columns]
451
452        return self._local_columns
453
454    @property
455    def unqualified_columns(self) -> list[exp.Column]:
456        """
457        Unqualified columns in the current scope.
458
459        Returns:
460             list[exp.Column]: Unqualified columns
461        """
462        return [c for c in self.columns if not c.text("table")]
463
464    @property
465    def join_hints(self) -> list[exp.JoinHint]:
466        """
467        Hints that exist in the scope that reference tables
468
469        Returns:
470            list[exp.JoinHint]: Join hints that are referenced within the scope
471        """
472        self._ensure_collected()
473        return self._join_hints
474
475    @property
476    def pivots(self) -> list[exp.Pivot]:
477        if self._pivots is None:
478            self._pivots = [
479                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
480            ]
481
482        return self._pivots
483
484    @property
485    def semi_or_anti_join_tables(self) -> set[str]:
486        self._ensure_collected()
487        return self._semi_anti_join_tables
488
489    def source_columns(self, source_name: str) -> list[exp.Column]:
490        """
491        Get all columns in the current scope for a particular source.
492
493        Args:
494            source_name (str): Name of the source
495        Returns:
496            list[exp.Column]: Column instances that reference `source_name`
497        """
498        return [column for column in self.columns if column.text("table") == source_name]
499
500    @property
501    def is_subquery(self) -> bool:
502        """Determine if this scope is a subquery"""
503        return self.scope_type == ScopeType.SUBQUERY
504
505    @property
506    def is_derived_table(self) -> bool:
507        """Determine if this scope is a derived table"""
508        return self.scope_type == ScopeType.DERIVED_TABLE
509
510    @property
511    def is_union(self) -> bool:
512        """Determine if this scope is a union"""
513        return self.scope_type == ScopeType.UNION
514
515    @property
516    def is_cte(self) -> bool:
517        """Determine if this scope is a common table expression"""
518        return self.scope_type == ScopeType.CTE
519
520    @property
521    def is_root(self) -> bool:
522        """Determine if this is the root scope"""
523        return self.scope_type == ScopeType.ROOT
524
525    @property
526    def is_udtf(self) -> bool:
527        """Determine if this scope is a UDTF (User Defined Table Function)"""
528        return self.scope_type == ScopeType.UDTF
529
530    @property
531    def is_correlated_subquery(self) -> bool:
532        """Determine if this scope is a correlated subquery"""
533        return bool(self.can_be_correlated and self.external_columns)
534
535    def rename_source(self, old_name: str | None, new_name: str) -> None:
536        """Rename a source in this scope"""
537        old_name = old_name or ""
538        if old_name in self.sources:
539            self.sources[new_name] = self.sources.pop(old_name)
540
541    def add_source(self, name: str, source: exp.Table | Scope) -> None:
542        """Add a source to this scope"""
543        self.sources[name] = source
544        self.clear_cache()
545
546    def remove_source(self, name: str) -> None:
547        """Remove a source from this scope"""
548        self.sources.pop(name, None)
549        self.clear_cache()
550
551    def __repr__(self) -> str:
552        return f"Scope<{self.expression.sql()}>"
553
554    def traverse(self) -> Iterator[Scope]:
555        """
556        Traverse the scope tree from this node.
557
558        Yields:
559            Scope: scope instances in depth-first-search post-order
560        """
561        stack: list[Scope] = [self]
562        result: list[Scope] = []
563        while stack:
564            scope = stack.pop()
565            result.append(scope)
566            stack.extend(
567                itertools.chain(
568                    scope.cte_scopes,
569                    scope.union_scopes,
570                    scope.table_scopes,
571                    scope.subquery_scopes,
572                )
573            )
574
575        yield from reversed(result)
576
577    def ref_count(self) -> dict[int, int]:
578        """
579        Count the number of times each scope in this tree is referenced.
580
581        Returns:
582            dict[int, int]: Mapping of Scope instance ID to reference count
583        """
584        scope_ref_count: dict[int, int] = defaultdict(int)
585
586        for scope in self.traverse():
587            for _, source in scope.selected_sources.values():
588                scope_ref_count[id(source)] += 1
589
590            for name in scope._semi_anti_join_tables:
591                # semi/anti join sources are not actually selected but we still need to
592                # increment their ref count to avoid them being optimized away
593                if name in scope.sources:
594                    scope_ref_count[id(scope.sources[name])] += 1
595
596        return scope_ref_count

Selection scope.

Attributes:
  • expression: Root expression of this scope
  • sources: Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
  • lateral_sources: Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
  • cte_sources: Sources from CTES
  • outer_columns: If this is a derived table or CTE, and the outer query defines a column list for the alias of this scope, this is that list of columns. For example: SELECT * FROM (SELECT ...) AS y(col1, col2) The inner query would have ["col1", "col2"] for its outer_columns
  • parent: Parent scope
  • scope_type: Type of this scope, relative to its parent
  • subquery_scopes: List of all child scopes for subqueries
  • cte_scopes: List of all child scopes for CTEs
  • derived_table_scopes: List of all child scopes for derived_tables
  • udtf_scopes: List of all child scopes for user defined tabular functions
  • table_scopes: derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes: If this Scope is for a Union expression, this will be a list of the left and right child scopes.
Scope( expression: sqlglot.expressions.core.Expr, sources: dict[str, sqlglot.expressions.query.Table | Scope] | None = None, outer_columns: list[str] | None = None, parent: Scope | None = None, scope_type: ScopeType = <ScopeType.ROOT: 1>, lateral_sources: dict[str, sqlglot.expressions.query.Table | Scope] | None = None, cte_sources: dict[str, sqlglot.expressions.query.Table | Scope] | None = None, can_be_correlated: bool | None = None)
 85    def __init__(
 86        self,
 87        expression: exp.Expr,
 88        sources: dict[str, exp.Table | Scope] | None = None,
 89        outer_columns: list[str] | None = None,
 90        parent: Scope | None = None,
 91        scope_type: ScopeType = ScopeType.ROOT,
 92        lateral_sources: dict[str, exp.Table | Scope] | None = None,
 93        cte_sources: dict[str, exp.Table | Scope] | None = None,
 94        can_be_correlated: bool | None = None,
 95    ) -> None:
 96        self.expression = expression
 97        self.sources = sources or {}
 98        self.lateral_sources = lateral_sources or {}
 99        self.cte_sources = cte_sources or {}
100        self.sources.update(self.lateral_sources)
101        self.sources.update(self.cte_sources)
102        self.outer_columns = outer_columns or []
103        self.parent = parent
104        self.scope_type = scope_type
105        self.subquery_scopes: list[Scope] = []
106        self.derived_table_scopes: list[Scope] = []
107        self.table_scopes: list[Scope] = []
108        self.cte_scopes: list[Scope] = []
109        self.union_scopes: list[Scope] = []
110        self.udtf_scopes: list[Scope] = []
111        self.can_be_correlated = can_be_correlated
112        self.clear_cache()
expression
sources
lateral_sources
cte_sources
outer_columns
parent
scope_type
subquery_scopes: list[Scope]
derived_table_scopes: list[Scope]
table_scopes: list[Scope]
cte_scopes: list[Scope]
union_scopes: list[Scope]
udtf_scopes: list[Scope]
can_be_correlated
def clear_cache(self) -> None:
114    def clear_cache(self) -> None:
115        self._collected = False
116        self._raw_columns = []
117        self._table_columns = []
118        self._stars = []
119        self._derived_tables = []
120        self._udtfs = []
121        self._tables = []
122        self._ctes = []
123        self._subqueries = []
124        self._join_hints = []
125        self._semi_anti_join_tables = set()
126        self._column_index = set()
127        self._selected_sources = None
128        self._columns = None
129        self._external_columns = None
130        self._local_columns = None
131        self._pivots = None
132        self._references = None
def branch( self, expression: sqlglot.expressions.core.Expr, scope_type: ScopeType, sources: dict[str, sqlglot.expressions.query.Table | Scope] | None = None, cte_sources: dict[str, sqlglot.expressions.query.Table | Scope] | None = None, lateral_sources: dict[str, sqlglot.expressions.query.Table | Scope] | None = None, outer_columns: list[str] | None = None) -> Scope:
134    def branch(
135        self,
136        expression: exp.Expr,
137        scope_type: ScopeType,
138        sources: dict[str, exp.Table | Scope] | None = None,
139        cte_sources: dict[str, exp.Table | Scope] | None = None,
140        lateral_sources: dict[str, exp.Table | Scope] | None = None,
141        outer_columns: list[str] | None = None,
142    ) -> Scope:
143        """Branch from the current scope to a new, inner scope"""
144        return Scope(
145            expression=expression.unnest(),
146            sources=sources.copy() if sources else None,
147            parent=self,
148            scope_type=scope_type,
149            cte_sources={**self.cte_sources, **(cte_sources or {})},
150            lateral_sources=lateral_sources.copy() if lateral_sources else None,
151            can_be_correlated=self.can_be_correlated
152            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
153            outer_columns=outer_columns,
154        )

Branch from the current scope to a new, inner scope

def walk( self, prune: Optional[Callable[[sqlglot.expressions.core.Expr], bool]] = None) -> Iterator[sqlglot.expressions.core.Expr]:
207    def walk(self, prune: t.Callable[[exp.Expr], bool] | None = None) -> Iterator[exp.Expr]:
208        return walk_in_scope(self.expression, prune=prune)
def find(self, *expression_types: type[~E]) -> Optional[~E]:
210    def find(self, *expression_types: Type[E]) -> E | None:
211        return find_in_scope(self.expression, *expression_types)
def find_all(self, *expression_types: type[~E]) -> Iterator[~E]:
213    def find_all(self, *expression_types: Type[E]) -> Iterator[E]:
214        return find_all_in_scope(self.expression, *expression_types)
def replace( self, old: sqlglot.expressions.core.Expr, new: sqlglot.expressions.core.Expr) -> None:
216    def replace(self, old: exp.Expr, new: exp.Expr) -> None:
217        """
218        Replace `old` with `new`.
219
220        This can be used instead of `exp.Expr.replace` to ensure the `Scope` is kept up-to-date.
221
222        Args:
223            old (exp.Expr): old node
224            new (exp.Expr): new node
225        """
226        old.replace(new)
227        self.clear_cache()

Replace old with new.

This can be used instead of exp.Expr.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expr): old node
  • new (exp.Expr): new node
tables: list[sqlglot.expressions.query.Table]
229    @property
230    def tables(self) -> list[exp.Table]:
231        """
232        List of tables in this scope.
233
234        Returns:
235            list[exp.Table]: tables
236        """
237        self._ensure_collected()
238        return self._tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes: list[sqlglot.expressions.query.CTE]
240    @property
241    def ctes(self) -> list[exp.CTE]:
242        """
243        List of CTEs in this scope.
244
245        Returns:
246            list[exp.CTE]: ctes
247        """
248        self._ensure_collected()
249        return self._ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables: list[sqlglot.expressions.query.Subquery]
251    @property
252    def derived_tables(self) -> list[exp.Subquery]:
253        """
254        List of derived tables in this scope.
255
256        For example:
257            SELECT * FROM (SELECT ...) <- that's a derived table
258
259        Returns:
260            list[exp.Subquery]: derived tables
261        """
262        self._ensure_collected()
263        return self._derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs: list[sqlglot.expressions.query.UDTF]
265    @property
266    def udtfs(self) -> list[exp.UDTF]:
267        """
268        List of "User Defined Table Functions" in this scope.
269
270        Returns:
271            list[exp.UDTF]: UDTFs
272        """
273        self._ensure_collected()
274        return self._udtfs

List of "User Defined Table Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

276    @property
277    def subqueries(self) -> list[exp.Select | exp.SetOperation]:
278        """
279        List of subqueries in this scope.
280
281        For example:
282            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
283
284        Returns:
285            list[exp.Select | exp.SetOperation]: subqueries
286        """
287        self._ensure_collected()
288        return self._subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Select | exp.SetOperation]: subqueries

290    @property
291    def stars(self) -> list[exp.Column | exp.Dot]:
292        """
293        List of star expressions (columns or dots) in this scope.
294        """
295        self._ensure_collected()
296        return self._stars

List of star expressions (columns or dots) in this scope.

column_index: set[int]
298    @property
299    def column_index(self) -> set[int]:
300        """
301        Set of column object IDs that belong to this scope's expression.
302        """
303        self._ensure_collected()
304        return self._column_index

Set of column object IDs that belong to this scope's expression.

columns: list[sqlglot.expressions.core.Column]
306    @property
307    def columns(self) -> list[exp.Column]:
308        """
309        List of columns in this scope.
310
311        Returns:
312            list[exp.Column]: Column instances in this scope, plus any
313                Columns that reference this scope from correlated subqueries.
314        """
315        if self._columns is None:
316            self._ensure_collected()
317            columns = self._raw_columns
318
319            external_columns = [
320                column
321                for scope in itertools.chain(
322                    self.subquery_scopes,
323                    self.udtf_scopes,
324                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
325                )
326                for column in scope.external_columns
327            ]
328
329            expr = self.expression
330            named_selects = set(expr.named_selects) if isinstance(expr, exp.Query) else set()
331
332            self._columns = []
333            for column in columns + external_columns:
334                ancestor = column.find_ancestor(
335                    exp.Select,
336                    exp.Qualify,
337                    exp.Order,
338                    exp.Having,
339                    exp.Hint,
340                    exp.Table,
341                    exp.Star,
342                    exp.Distinct,
343                )
344                if (
345                    not ancestor
346                    or column.text("table")
347                    or isinstance(ancestor, exp.Select)
348                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
349                    or (
350                        isinstance(ancestor, (exp.Order, exp.Distinct))
351                        and (
352                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
353                            or not isinstance(ancestor.parent, exp.Select)
354                            or column.name not in named_selects
355                        )
356                    )
357                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
358                ):
359                    self._columns.append(column)
360
361        return self._columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

table_columns: list[sqlglot.expressions.query.TableColumn]
363    @property
364    def table_columns(self) -> list[exp.TableColumn]:
365        self._ensure_collected()
366        return self._table_columns
selected_sources: dict[str, tuple[sqlglot.expressions.query.Selectable, sqlglot.expressions.query.Table | Scope]]
368    @property
369    def selected_sources(self) -> dict[str, tuple[exp.Selectable, exp.Table | Scope]]:
370        """
371        Mapping of nodes and sources that are actually selected from in this scope.
372
373        That is, all tables in a schema are selectable at any point. But a
374        table only becomes a selected source if it's included in a FROM or JOIN clause.
375
376        Returns:
377            dict[str, tuple[exp.Selectable, exp.Table | Scope]]: source name mapped to its (node, source) pair
378        """
379        if self._selected_sources is None:
380            result: dict[str, tuple[exp.Selectable, exp.Table | Scope]] = {}
381
382            for name, node in self.references:
383                if name in self._semi_anti_join_tables:
384                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
385                    # selected source
386                    continue
387
388                if name in result:
389                    raise OptimizeError(f"Alias already used: {name}")
390                if name in self.sources:
391                    result[name] = (node, self.sources[name])
392
393            self._selected_sources = result
394        return self._selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, tuple[exp.Selectable, exp.Table | Scope]]: source name mapped to its (node, source) pair

references: list[tuple[str, sqlglot.expressions.query.Selectable]]
396    @property
397    def references(self) -> list[tuple[str, exp.Selectable]]:
398        if self._references is None:
399            self._references = []
400
401            for table in self.tables:
402                self._references.append((table.alias_or_name, table))
403            for _expr in itertools.chain(self.derived_tables, self.udtfs):
404                # TODO (mypyc): rebind to exp.Expr to avoid DerivedTable trait vtable dispatch
405                expression: exp.Expr = _expr
406                self._references.append(
407                    (
408                        _get_source_alias(expression),
409                        (
410                            expression if expression.args.get("pivots") else expression.unnest()
411                        ).assert_is(exp.Selectable),
412                    )
413                )
414
415        return self._references
external_columns: list[sqlglot.expressions.core.Column]
417    @property
418    def external_columns(self) -> list[exp.Column]:
419        """
420        Columns that appear to reference sources in outer scopes.
421
422        Returns:
423            list[exp.Column]: Column instances that don't reference sources in the current scope.
424        """
425        if self._external_columns is None:
426            if isinstance(self.expression, exp.SetOperation):
427                left, right = self.union_scopes
428                self._external_columns = left.external_columns + right.external_columns
429            else:
430                local_source_names = {name for name, _ in self.references}
431                self._external_columns = [
432                    c
433                    for c in self.columns
434                    if c.text("table") not in local_source_names
435                    and c.text("table") not in self.semi_or_anti_join_tables
436                ]
437
438        return self._external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

local_columns: list[sqlglot.expressions.core.Column]
440    @property
441    def local_columns(self) -> list[exp.Column]:
442        """
443        Columns in this scope that are not external.
444
445        Returns:
446            list[exp.Column]: Column instances that reference sources in the current scope.
447        """
448        if self._local_columns is None:
449            external_columns = set(self.external_columns)
450            self._local_columns = [c for c in self.columns if c not in external_columns]
451
452        return self._local_columns

Columns in this scope that are not external.

Returns:

list[exp.Column]: Column instances that reference sources in the current scope.

unqualified_columns: list[sqlglot.expressions.core.Column]
454    @property
455    def unqualified_columns(self) -> list[exp.Column]:
456        """
457        Unqualified columns in the current scope.
458
459        Returns:
460             list[exp.Column]: Unqualified columns
461        """
462        return [c for c in self.columns if not c.text("table")]

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints: list[sqlglot.expressions.core.JoinHint]
464    @property
465    def join_hints(self) -> list[exp.JoinHint]:
466        """
467        Hints that exist in the scope that reference tables
468
469        Returns:
470            list[exp.JoinHint]: Join hints that are referenced within the scope
471        """
472        self._ensure_collected()
473        return self._join_hints

Hints that exist in the scope that reference tables

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

pivots: list[sqlglot.expressions.query.Pivot]
475    @property
476    def pivots(self) -> list[exp.Pivot]:
477        if self._pivots is None:
478            self._pivots = [
479                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
480            ]
481
482        return self._pivots
semi_or_anti_join_tables: set[str]
484    @property
485    def semi_or_anti_join_tables(self) -> set[str]:
486        self._ensure_collected()
487        return self._semi_anti_join_tables
def source_columns(self, source_name: str) -> list[sqlglot.expressions.core.Column]:
489    def source_columns(self, source_name: str) -> list[exp.Column]:
490        """
491        Get all columns in the current scope for a particular source.
492
493        Args:
494            source_name (str): Name of the source
495        Returns:
496            list[exp.Column]: Column instances that reference `source_name`
497        """
498        return [column for column in self.columns if column.text("table") == source_name]

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery: bool
500    @property
501    def is_subquery(self) -> bool:
502        """Determine if this scope is a subquery"""
503        return self.scope_type == ScopeType.SUBQUERY

Determine if this scope is a subquery

is_derived_table: bool
505    @property
506    def is_derived_table(self) -> bool:
507        """Determine if this scope is a derived table"""
508        return self.scope_type == ScopeType.DERIVED_TABLE

Determine if this scope is a derived table

is_union: bool
510    @property
511    def is_union(self) -> bool:
512        """Determine if this scope is a union"""
513        return self.scope_type == ScopeType.UNION

Determine if this scope is a union

is_cte: bool
515    @property
516    def is_cte(self) -> bool:
517        """Determine if this scope is a common table expression"""
518        return self.scope_type == ScopeType.CTE

Determine if this scope is a common table expression

is_root: bool
520    @property
521    def is_root(self) -> bool:
522        """Determine if this is the root scope"""
523        return self.scope_type == ScopeType.ROOT

Determine if this is the root scope

is_udtf: bool
525    @property
526    def is_udtf(self) -> bool:
527        """Determine if this scope is a UDTF (User Defined Table Function)"""
528        return self.scope_type == ScopeType.UDTF

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery: bool
530    @property
531    def is_correlated_subquery(self) -> bool:
532        """Determine if this scope is a correlated subquery"""
533        return bool(self.can_be_correlated and self.external_columns)

Determine if this scope is a correlated subquery

def rename_source(self, old_name: str | None, new_name: str) -> None:
def rename_source(self, old_name: str | None, new_name: str) -> None:
    """
    Rename a source in this scope.

    Args:
        old_name: Current name of the source. `None` is treated as the empty string.
        new_name: Name to register the source under.

    Does nothing if `old_name` is not a key of `self.sources`.
    """
    old_name = old_name or ""
    if old_name in self.sources:
        self.sources[new_name] = self.sources.pop(old_name)
        # `selected_sources` memoises lookups into `self.sources`, so invalidate
        # the caches here the same way `add_source` and `remove_source` do.
        self.clear_cache()

Rename a source in this scope

def add_source( self, name: str, source: sqlglot.expressions.query.Table | Scope) -> None:
541    def add_source(self, name: str, source: exp.Table | Scope) -> None:
542        """Add a source to this scope"""
543        self.sources[name] = source
544        self.clear_cache()

Add a source to this scope

def remove_source(self, name: str) -> None:
546    def remove_source(self, name: str) -> None:
547        """Remove a source from this scope"""
548        self.sources.pop(name, None)
549        self.clear_cache()

Remove a source from this scope

def traverse(self) -> Iterator[Scope]:
554    def traverse(self) -> Iterator[Scope]:
555        """
556        Traverse the scope tree from this node.
557
558        Yields:
559            Scope: scope instances in depth-first-search post-order
560        """
561        stack: list[Scope] = [self]
562        result: list[Scope] = []
563        while stack:
564            scope = stack.pop()
565            result.append(scope)
566            stack.extend(
567                itertools.chain(
568                    scope.cte_scopes,
569                    scope.union_scopes,
570                    scope.table_scopes,
571                    scope.subquery_scopes,
572                )
573            )
574
575        yield from reversed(result)

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self) -> dict[int, int]:
577    def ref_count(self) -> dict[int, int]:
578        """
579        Count the number of times each scope in this tree is referenced.
580
581        Returns:
582            dict[int, int]: Mapping of Scope instance ID to reference count
583        """
584        scope_ref_count: dict[int, int] = defaultdict(int)
585
586        for scope in self.traverse():
587            for _, source in scope.selected_sources.values():
588                scope_ref_count[id(source)] += 1
589
590            for name in scope._semi_anti_join_tables:
591                # semi/anti join sources are not actually selected but we still need to
592                # increment their ref count to avoid them being optimized away
593                if name in scope.sources:
594                    scope_ref_count[id(scope.sources[name])] += 1
595
596        return scope_ref_count

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope( expression: sqlglot.expressions.core.Expr) -> list[Scope]:
599def traverse_scope(expression: exp.Expr) -> list[Scope]:
600    """
601    Traverse an expression by its "scopes".
602
603    "Scope" represents the current context of a Select statement.
604
605    This is helpful for optimizing queries, where we need more information than
606    the expression tree itself. For example, we might care about the source
607    names within a subquery. Returns a list because a generator could result in
608    incomplete properties which is confusing.
609
610    Examples:
611        >>> import sqlglot
612        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
613        >>> scopes = traverse_scope(expression)
614        >>> scopes[0].expression.sql(), list(scopes[0].sources)
615        ('SELECT a FROM x', ['x'])
616        >>> scopes[1].expression.sql(), list(scopes[1].sources)
617        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
618
619    Args:
620        expression: Expr to traverse
621
622    Returns:
623        A list of the created scope instances
624    """
625    if isinstance(expression, TRAVERSABLES):
626        return list(_traverse_scope(Scope(expression)))
627    return []

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression: Expr to traverse
Returns:

A list of the created scope instances

def build_scope( expression: sqlglot.expressions.core.Expr) -> Scope | None:
630def build_scope(expression: exp.Expr) -> Scope | None:
631    """
632    Build a scope tree.
633
634    Args:
635        expression: Expr to build the scope tree for.
636
637    Returns:
638        The root scope
639    """
640    return seq_get(traverse_scope(expression), -1)

Build a scope tree.

Arguments:
  • expression: Expr to build the scope tree for.
Returns:

The root scope

def walk_in_scope( expression: sqlglot.expressions.core.Expr, prune: Optional[Callable[[sqlglot.expressions.core.Expr], bool]] = None) -> Iterator[sqlglot.expressions.core.Expr]:
915def walk_in_scope(
916    expression: exp.Expr,
917    prune: t.Callable[[exp.Expr], bool] | None = None,
918) -> Iterator[exp.Expr]:
919    """
920    Returns a generator object which visits all nodes in the syntax tree, stopping at
921    nodes that start child scopes. This does a custom DFS traversal rather than using
922    expression.walk() because the nested generators aren't optimized in mypyc.
923
924    Args:
925        expression: the expression to walk.
926        prune: callable that returns True if
927            the generator should stop traversing this branch of the tree.
928
929    Yields:
930        exp.Expr: each node in scope
931    """
932    stack: list[exp.Expr] = [expression]
933
934    while stack:
935        node = stack.pop()
936
937        yield node
938
939        if node is not expression and (
940            isinstance(node, exp.CTE)
941            or (isinstance(node.parent, (exp.From, exp.Join)) and _is_derived_table(node))
942            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
943            or isinstance(node, exp.UNWRAPPED_QUERIES)
944        ):
945            if isinstance(node, (exp.Subquery, exp.UDTF)):
946                for key in ("joins", "laterals", "pivots"):
947                    for arg in node.args.get(key) or []:
948                        yield from walk_in_scope(arg)
949            continue
950
951        if prune and prune(node):
952            continue
953
954        for vs in reversed(node.args.values()):
955            if isinstance(vs, list):
956                for v in reversed(vs):
957                    if isinstance(v, exp.Expr):
958                        stack.append(v)
959            elif isinstance(vs, exp.Expr):
960                stack.append(vs)

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes. This does a custom DFS traversal rather than using expression.walk() because the nested generators aren't optimized in mypyc.

Arguments:
  • expression: the expression to walk.
  • prune: callable that returns True if the generator should stop traversing this branch of the tree.
Yields:

exp.Expr: each node in scope

def find_all_in_scope( expression: sqlglot.expressions.core.Expr, *expression_types: type[~E]) -> Iterator[~E]:
963def find_all_in_scope(
964    expression: exp.Expr,
965    *expression_types: Type[E],
966) -> Iterator[E]:
967    """
968    Returns a generator object which visits all nodes in this scope and only yields those that
969    match at least one of the specified expression types.
970
971    This does NOT traverse into subscopes.
972
973    Args:
974        expression: the expression to search.
975        expression_types: the expression type(s) to match.
976
977    Yields:
978        The matching nodes.
979    """
980    for node in walk_in_scope(expression):
981        if isinstance(node, expression_types):
982            yield node

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression: the expression to search.
  • expression_types: the expression type(s) to match.
Yields:

The matching nodes.

def find_in_scope( expression: sqlglot.expressions.core.Expr, *expression_types: type[~E]) -> Optional[~E]:
 985def find_in_scope(
 986    expression: exp.Expr,
 987    *expression_types: Type[E],
 988) -> E | None:
 989    """
 990    Returns the first node in this scope which matches at least one of the specified types.
 991
 992    This does NOT traverse into subscopes.
 993
 994    Args:
 995        expression: the expression to search.
 996        expression_types: the expression type(s) to match.
 997
 998    Returns:
 999        The node which matches the criteria or None if no node matching
1000        the criteria was found.
1001    """
1002    return next(find_all_in_scope(expression, *expression_types), None)

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression: the expression to search.
  • expression_types: the expression type(s) to match.
Returns:

The node which matches the criteria or None if no node matching the criteria was found.