Edit on GitHub

sqlglot.optimizer.scope

  1from __future__ import annotations
  2
  3import itertools
  4import logging
  5import typing as t
  6from collections import defaultdict
  7from enum import Enum, auto
  8
  9from sqlglot import exp
 10from sqlglot.errors import OptimizeError
 11from sqlglot.helper import ensure_collection, find_new_name, seq_get
 12
# Shared "sqlglot" logger; used in this module to warn about expressions that cannot be traversed.
logger = logging.getLogger("sqlglot")

# Expression types accepted as the root of a scope traversal. `traverse_scope`
# returns an empty list for anything that is not an instance of one of these.
TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)
 16
 17
 18class ScopeType(Enum):
 19    ROOT = auto()
 20    SUBQUERY = auto()
 21    DERIVED_TABLE = auto()
 22    CTE = auto()
 23    UNION = auto()
 24    UDTF = auto()
 25
 26
 27class Scope:
 28    """
 29    Selection scope.
 30
 31    Attributes:
 32        expression (exp.Select|exp.SetOperation): Root expression of this scope
 33        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 34            a Table expression or another Scope instance. For example:
 35                SELECT * FROM x                     {"x": Table(this="x")}
 36                SELECT * FROM x AS y                {"y": Table(this="x")}
 37                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 38        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 39            For example:
 40                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 41            The LATERAL VIEW EXPLODE gets x as a source.
 42        cte_sources (dict[str, Scope]): Sources from CTES
 43        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
 44            defines a column list for the alias of this scope, this is that list of columns.
 45            For example:
 46                SELECT * FROM (SELECT ...) AS y(col1, col2)
 47            The inner query would have `["col1", "col2"]` for its `outer_columns`
 48        parent (Scope): Parent scope
 49        scope_type (ScopeType): Type of this scope, relative to it's parent
 50        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 51        cte_scopes (list[Scope]): List of all child scopes for CTEs
 52        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 53        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 54        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 55        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 56            a list of the left and right child scopes.
 57    """
 58
 59    def __init__(
 60        self,
 61        expression,
 62        sources=None,
 63        outer_columns=None,
 64        parent=None,
 65        scope_type=ScopeType.ROOT,
 66        lateral_sources=None,
 67        cte_sources=None,
 68        can_be_correlated=None,
 69    ):
 70        self.expression = expression
 71        self.sources = sources or {}
 72        self.lateral_sources = lateral_sources or {}
 73        self.cte_sources = cte_sources or {}
 74        self.sources.update(self.lateral_sources)
 75        self.sources.update(self.cte_sources)
 76        self.outer_columns = outer_columns or []
 77        self.parent = parent
 78        self.scope_type = scope_type
 79        self.subquery_scopes = []
 80        self.derived_table_scopes = []
 81        self.table_scopes = []
 82        self.cte_scopes = []
 83        self.union_scopes = []
 84        self.udtf_scopes = []
 85        self.can_be_correlated = can_be_correlated
 86        self.clear_cache()
 87
 88    def clear_cache(self):
 89        self._collected = False
 90        self._raw_columns = None
 91        self._table_columns = None
 92        self._stars = None
 93        self._derived_tables = None
 94        self._udtfs = None
 95        self._tables = None
 96        self._ctes = None
 97        self._subqueries = None
 98        self._selected_sources = None
 99        self._columns = None
100        self._external_columns = None
101        self._local_columns = None
102        self._join_hints = None
103        self._pivots = None
104        self._references = None
105        self._semi_anti_join_tables = None
106        self._column_index = None
107
108    def branch(
109        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
110    ):
111        """Branch from the current scope to a new, inner scope"""
112        return Scope(
113            expression=expression.unnest(),
114            sources=sources.copy() if sources else None,
115            parent=self,
116            scope_type=scope_type,
117            cte_sources={**self.cte_sources, **(cte_sources or {})},
118            lateral_sources=lateral_sources.copy() if lateral_sources else None,
119            can_be_correlated=self.can_be_correlated
120            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
121            **kwargs,
122        )
123
124    def _collect(self):
125        self._tables = []
126        self._ctes = []
127        self._subqueries = []
128        self._derived_tables = []
129        self._udtfs = []
130        self._raw_columns = []
131        self._table_columns = []
132        self._stars = []
133        self._join_hints = []
134        self._semi_anti_join_tables = set()
135        self._column_index = set()
136
137        for node in self.walk(bfs=False):
138            if node is self.expression:
139                continue
140
141            node_type = type(node)
142
143            if node_type is exp.Dot and node.is_star:
144                self._stars.append(node)
145            elif node_type is exp.Column:
146                self._column_index.add(id(node))
147
148                if type(node.this) is exp.Star:
149                    self._stars.append(node)
150                else:
151                    self._raw_columns.append(node)
152            elif node_type is exp.Table and type(node.parent) is not exp.JoinHint:
153                parent = node.parent
154                if type(parent) is exp.Join and parent.is_semi_or_anti_join:
155                    self._semi_anti_join_tables.add(node.alias_or_name)
156
157                self._tables.append(node)
158            elif node_type is exp.JoinHint:
159                self._join_hints.append(node)
160            elif isinstance(node, exp.UDTF):
161                self._udtfs.append(node)
162            elif node_type is exp.CTE:
163                self._ctes.append(node)
164            elif _is_derived_table(node) and _is_from_or_join(node):
165                self._derived_tables.append(node)
166            elif isinstance(node, exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node):
167                self._subqueries.append(node)
168            elif node_type is exp.TableColumn:
169                self._table_columns.append(node)
170
171        self._collected = True
172
173    def _ensure_collected(self):
174        if not self._collected:
175            self._collect()
176
177    def walk(self, bfs=True, prune=None):
178        return walk_in_scope(self.expression, bfs=bfs, prune=None)
179
180    def find(self, *expression_types, bfs=True):
181        return find_in_scope(self.expression, expression_types, bfs=bfs)
182
183    def find_all(self, *expression_types, bfs=True):
184        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
185
186    def replace(self, old, new):
187        """
188        Replace `old` with `new`.
189
190        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
191
192        Args:
193            old (exp.Expression): old node
194            new (exp.Expression): new node
195        """
196        old.replace(new)
197        self.clear_cache()
198
199    @property
200    def tables(self):
201        """
202        List of tables in this scope.
203
204        Returns:
205            list[exp.Table]: tables
206        """
207        self._ensure_collected()
208        return self._tables
209
210    @property
211    def ctes(self):
212        """
213        List of CTEs in this scope.
214
215        Returns:
216            list[exp.CTE]: ctes
217        """
218        self._ensure_collected()
219        return self._ctes
220
221    @property
222    def derived_tables(self):
223        """
224        List of derived tables in this scope.
225
226        For example:
227            SELECT * FROM (SELECT ...) <- that's a derived table
228
229        Returns:
230            list[exp.Subquery]: derived tables
231        """
232        self._ensure_collected()
233        return self._derived_tables
234
235    @property
236    def udtfs(self):
237        """
238        List of "User Defined Tabular Functions" in this scope.
239
240        Returns:
241            list[exp.UDTF]: UDTFs
242        """
243        self._ensure_collected()
244        return self._udtfs
245
246    @property
247    def subqueries(self):
248        """
249        List of subqueries in this scope.
250
251        For example:
252            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
253
254        Returns:
255            list[exp.Select | exp.SetOperation]: subqueries
256        """
257        self._ensure_collected()
258        return self._subqueries
259
260    @property
261    def stars(self) -> t.List[exp.Column | exp.Dot]:
262        """
263        List of star expressions (columns or dots) in this scope.
264        """
265        self._ensure_collected()
266        return self._stars
267
268    @property
269    def column_index(self) -> t.Set[int]:
270        """
271        Set of column object IDs that belong to this scope's expression.
272        """
273        self._ensure_collected()
274        return self._column_index
275
276    @property
277    def columns(self):
278        """
279        List of columns in this scope.
280
281        Returns:
282            list[exp.Column]: Column instances in this scope, plus any
283                Columns that reference this scope from correlated subqueries.
284        """
285        if self._columns is None:
286            self._ensure_collected()
287            columns = self._raw_columns
288
289            external_columns = [
290                column
291                for scope in itertools.chain(
292                    self.subquery_scopes,
293                    self.udtf_scopes,
294                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
295                )
296                for column in scope.external_columns
297            ]
298
299            named_selects = set(self.expression.named_selects)
300
301            self._columns = []
302            for column in columns + external_columns:
303                ancestor = column.find_ancestor(
304                    exp.Select,
305                    exp.Qualify,
306                    exp.Order,
307                    exp.Having,
308                    exp.Hint,
309                    exp.Table,
310                    exp.Star,
311                    exp.Distinct,
312                )
313                if (
314                    not ancestor
315                    or column.table
316                    or isinstance(ancestor, exp.Select)
317                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
318                    or (
319                        isinstance(ancestor, (exp.Order, exp.Distinct))
320                        and (
321                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
322                            or not isinstance(ancestor.parent, exp.Select)
323                            or column.name not in named_selects
324                        )
325                    )
326                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
327                ):
328                    self._columns.append(column)
329
330        return self._columns
331
332    @property
333    def table_columns(self):
334        if self._table_columns is None:
335            self._ensure_collected()
336
337        return self._table_columns
338
339    @property
340    def selected_sources(self):
341        """
342        Mapping of nodes and sources that are actually selected from in this scope.
343
344        That is, all tables in a schema are selectable at any point. But a
345        table only becomes a selected source if it's included in a FROM or JOIN clause.
346
347        Returns:
348            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
349        """
350        if self._selected_sources is None:
351            result = {}
352
353            for name, node in self.references:
354                if name in self._semi_anti_join_tables:
355                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
356                    # selected source
357                    continue
358
359                if name in result:
360                    raise OptimizeError(f"Alias already used: {name}")
361                if name in self.sources:
362                    result[name] = (node, self.sources[name])
363
364            self._selected_sources = result
365        return self._selected_sources
366
367    @property
368    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
369        if self._references is None:
370            self._references = []
371
372            for table in self.tables:
373                self._references.append((table.alias_or_name, table))
374            for expression in itertools.chain(self.derived_tables, self.udtfs):
375                self._references.append(
376                    (
377                        _get_source_alias(expression),
378                        expression if expression.args.get("pivots") else expression.unnest(),
379                    )
380                )
381
382        return self._references
383
384    @property
385    def external_columns(self):
386        """
387        Columns that appear to reference sources in outer scopes.
388
389        Returns:
390            list[exp.Column]: Column instances that don't reference sources in the current scope.
391        """
392        if self._external_columns is None:
393            if isinstance(self.expression, exp.SetOperation):
394                left, right = self.union_scopes
395                self._external_columns = left.external_columns + right.external_columns
396            else:
397                self._external_columns = [
398                    c
399                    for c in self.columns
400                    if c.table not in self.sources and c.table not in self.semi_or_anti_join_tables
401                ]
402
403        return self._external_columns
404
405    @property
406    def local_columns(self):
407        """
408        Columns in this scope that are not external.
409
410        Returns:
411            list[exp.Column]: Column instances that reference sources in the current scope.
412        """
413        if self._local_columns is None:
414            external_columns = set(self.external_columns)
415            self._local_columns = [c for c in self.columns if c not in external_columns]
416
417        return self._local_columns
418
419    @property
420    def unqualified_columns(self):
421        """
422        Unqualified columns in the current scope.
423
424        Returns:
425             list[exp.Column]: Unqualified columns
426        """
427        return [c for c in self.columns if not c.table]
428
429    @property
430    def join_hints(self):
431        """
432        Hints that exist in the scope that reference tables
433
434        Returns:
435            list[exp.JoinHint]: Join hints that are referenced within the scope
436        """
437        if self._join_hints is None:
438            return []
439        return self._join_hints
440
441    @property
442    def pivots(self):
443        if not self._pivots:
444            self._pivots = [
445                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
446            ]
447
448        return self._pivots
449
450    @property
451    def semi_or_anti_join_tables(self):
452        return self._semi_anti_join_tables or set()
453
454    def source_columns(self, source_name):
455        """
456        Get all columns in the current scope for a particular source.
457
458        Args:
459            source_name (str): Name of the source
460        Returns:
461            list[exp.Column]: Column instances that reference `source_name`
462        """
463        return [column for column in self.columns if column.table == source_name]
464
465    @property
466    def is_subquery(self):
467        """Determine if this scope is a subquery"""
468        return self.scope_type == ScopeType.SUBQUERY
469
470    @property
471    def is_derived_table(self):
472        """Determine if this scope is a derived table"""
473        return self.scope_type == ScopeType.DERIVED_TABLE
474
475    @property
476    def is_union(self):
477        """Determine if this scope is a union"""
478        return self.scope_type == ScopeType.UNION
479
480    @property
481    def is_cte(self):
482        """Determine if this scope is a common table expression"""
483        return self.scope_type == ScopeType.CTE
484
485    @property
486    def is_root(self):
487        """Determine if this is the root scope"""
488        return self.scope_type == ScopeType.ROOT
489
490    @property
491    def is_udtf(self):
492        """Determine if this scope is a UDTF (User Defined Table Function)"""
493        return self.scope_type == ScopeType.UDTF
494
495    @property
496    def is_correlated_subquery(self):
497        """Determine if this scope is a correlated subquery"""
498        return bool(self.can_be_correlated and self.external_columns)
499
500    def rename_source(self, old_name, new_name):
501        """Rename a source in this scope"""
502        old_name = old_name or ""
503        if old_name in self.sources:
504            self.sources[new_name] = self.sources.pop(old_name)
505
506    def add_source(self, name, source):
507        """Add a source to this scope"""
508        self.sources[name] = source
509        self.clear_cache()
510
511    def remove_source(self, name):
512        """Remove a source from this scope"""
513        self.sources.pop(name, None)
514        self.clear_cache()
515
516    def __repr__(self):
517        return f"Scope<{self.expression.sql()}>"
518
519    def traverse(self):
520        """
521        Traverse the scope tree from this node.
522
523        Yields:
524            Scope: scope instances in depth-first-search post-order
525        """
526        stack = [self]
527        result = []
528        while stack:
529            scope = stack.pop()
530            result.append(scope)
531            stack.extend(
532                itertools.chain(
533                    scope.cte_scopes,
534                    scope.union_scopes,
535                    scope.table_scopes,
536                    scope.subquery_scopes,
537                )
538            )
539
540        yield from reversed(result)
541
542    def ref_count(self):
543        """
544        Count the number of times each scope in this tree is referenced.
545
546        Returns:
547            dict[int, int]: Mapping of Scope instance ID to reference count
548        """
549        scope_ref_count = defaultdict(lambda: 0)
550
551        for scope in self.traverse():
552            for _, source in scope.selected_sources.values():
553                scope_ref_count[id(source)] += 1
554
555            for name in scope._semi_anti_join_tables:
556                # semi/anti join sources are not actually selected but we still need to
557                # increment their ref count to avoid them being optimized away
558                if name in scope.sources:
559                    scope_ref_count[id(scope.sources[name])] += 1
560
561        return scope_ref_count
562
563
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Build every scope of `expression` and return them as a list.

    A "Scope" captures the context of a Select statement (for instance the names
    of the sources visible inside a subquery), which optimizations need on top of
    the expression tree. The result is materialized as a list, because scopes
    handed out lazily by a generator could still have incomplete properties.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty if `expression` is not
        an instance of one of the `TRAVERSABLES` types.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))
593
594
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build the scope tree of `expression`.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, i.e. the last scope produced by `traverse_scope`
        (fetched with `seq_get`, so a missing element does not raise).
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)
606
607
def _traverse_scope(scope):
    """
    Yield the child scopes of `scope`, dispatching on the type of its expression.

    For Select, Subquery, Table and UDTF expressions `scope` itself is yielded last.
    The SetOperation, DDL and DML branches return early and never reach the final
    `yield scope`. Unsupported expression types are logged and yield nothing.
    """
    expression = scope.expression

    if isinstance(expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(expression, exp.SetOperation):
        yield from _traverse_ctes(scope)
        # `_traverse_union` starts its scope stack with `scope` and yields the union
        # scopes itself, hence the early return.
        yield from _traverse_union(scope)
        return
    elif isinstance(expression, exp.Subquery):
        if scope.is_root:
            yield from _traverse_select(scope)
        else:
            yield from _traverse_subqueries(scope)
    elif isinstance(expression, exp.Table):
        yield from _traverse_tables(scope)
    elif isinstance(expression, exp.UDTF):
        yield from _traverse_udtfs(scope)
    elif isinstance(expression, exp.DDL):
        # Only DDL wrapping a query is traversed: the query gets a new root scope
        # that shares the CTE sources of the DDL scope.
        if isinstance(expression.expression, exp.Query):
            yield from _traverse_ctes(scope)
            yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources))
        return
    elif isinstance(expression, exp.DML):
        yield from _traverse_ctes(scope)
        for query in find_all_in_scope(expression, exp.Query):
            # This check ensures we don't yield the CTE/nested queries twice
            if not isinstance(query.parent, (exp.CTE, exp.Subquery)):
                yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
        return
    else:
        logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
        return

    yield scope
643
644
def _traverse_select(scope):
    """Yield the child scopes of a select: first its CTEs, then its tables, then its subqueries."""
    for traverse in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traverse(scope)
649
650
def _traverse_union(scope):
    """
    Yield the scopes of a (possibly nested) set operation, using explicit stacks
    instead of recursion.

    Each operand is branched into a UNION scope. For every set operation on the
    scope stack, `union_scopes` is assigned a two-element list before that scope
    is yielded.
    """
    # Last completed scope.
    # NOTE(review): appears to act as the left operand of the next pairing - confirm.
    prev_scope = None
    union_scope_stack = [scope]
    # The right operand is pushed first so that the left operand is popped (visited) first.
    expression_stack = [scope.expression.right, scope.expression.left]

    while expression_stack:
        expression = expression_stack.pop()
        union_scope = union_scope_stack[-1]

        new_scope = union_scope.branch(
            expression,
            outer_columns=union_scope.outer_columns,
            scope_type=ScopeType.UNION,
        )

        if isinstance(expression, exp.SetOperation):
            # Nested set operation: descend into its operands before completing it.
            yield from _traverse_ctes(new_scope)

            union_scope_stack.append(new_scope)
            expression_stack.extend([expression.right, expression.left])
            continue

        # Note that `scope` is rebound here; after the loop it is the last scope yielded
        # for this operand.
        for scope in _traverse_scope(new_scope):
            yield scope

        if prev_scope:
            # Both operands of the innermost pending set operation are done.
            union_scope_stack.pop()
            union_scope.union_scopes = [prev_scope, scope]
            prev_scope = union_scope

            yield union_scope
        else:
            prev_scope = scope
684
685
def _traverse_ctes(scope):
    """
    Yield the scopes of every CTE of `scope`.

    The last scope yielded for a CTE is recorded in `scope.cte_scopes` and becomes
    the source registered under the CTE's alias. Once all CTEs are processed, the
    collected sources are merged into `scope.sources` and `scope.cte_sources`.
    """
    sources = {}

    for cte in scope.ctes:
        cte_name = cte.alias

        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
        # thus the recursive scope is the first section of the union.
        with_ = scope.expression.args.get("with_")
        if with_ and with_.recursive:
            union = cte.this

            if isinstance(union, exp.SetOperation):
                # Register a source for the CTE's own name before traversing its body,
                # so the body can refer to the CTE itself.
                sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        # CTEs defined earlier in this loop are passed on as sources of the later ones.
        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                cte_sources=sources,
                outer_columns=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

        # append the final child_scope yielded
        if child_scope:
            sources[cte_name] = child_scope
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
    scope.cte_sources.update(sources)
720
721
def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    Tell whether `expression` is a Subquery that acts as a derived table.

    (tbl1 JOIN tbl2) is also represented as a Subquery, but it isn't really a
    "derived table" because it doesn't introduce a new scope. The exception is
    an aliased Subquery: the alias shadows all names underneath it.
    """
    if type(expression) is not exp.Subquery:
        return False

    wraps_query = isinstance(expression.this, exp.UNWRAPPED_QUERIES)
    return bool(expression.alias or wraps_query)
731
732
def _is_from_or_join(expression: exp.Expression) -> bool:
    """
    Tell whether `expression` sits directly in the FROM or JOIN clause of a SELECT
    statement, looking through any number of wrapping Subquery nodes.
    """
    ancestor = expression.parent

    # Skip over arbitrarily nested Subquery wrappers
    while type(ancestor) is exp.Subquery:
        ancestor = ancestor.parent

    clause_types = (exp.From, exp.Join)
    return type(ancestor) in clause_types
744
745
def _traverse_tables(scope):
    """
    Resolve the sources in the FROM, JOIN and LATERAL clauses of `scope` and yield
    the scopes of any derived tables or UDTFs among them.

    Plain tables (or references to already known sources such as CTEs) are recorded
    by name. Derived tables and UDTFs are branched into child scopes, which are
    appended to `scope.derived_table_scopes` / `scope.udtf_scopes` and to
    `scope.table_scopes`. All sources found are finally merged into `scope.sources`.
    """
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from_")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # NOTE: `expressions` grows while it is being iterated; the items appended below
    # are visited by this same loop.
    for expression in expressions:
        if isinstance(expression, exp.Final):
            expression = expression.this
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                # The name is already taken in this scope, so register it under a fresh name.
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            # The UDTF scope receives the sources collected so far as its lateral sources.
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_columns=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            sources[_get_source_alias(expression)] = child_scope

        # append the final child_scope yielded
        if child_scope:
            scopes.append(child_scope)
            scope.table_scopes.append(child_scope)

    scope.sources.update(sources)
831
832
def _traverse_subqueries(scope):
    """
    Yield the scopes of every subquery of `scope`.

    For each subquery, the last scope yielded (None if nothing was yielded) is
    appended to `scope.subquery_scopes`.
    """
    for subquery in scope.subqueries:
        branched = scope.branch(subquery, scope_type=ScopeType.SUBQUERY)

        last_scope = None
        for last_scope in _traverse_scope(branched):
            yield last_scope

        scope.subquery_scopes.append(last_scope)
840
841
def _traverse_udtfs(scope):
    """
    Yield the scopes of the subqueries used as arguments of a UDTF scope.

    The candidates are the expressions of an Unnest, or the single argument of a
    Lateral; any other expression has none. For each Subquery among them, the last
    scope yielded is appended to `scope.subquery_scopes` and registered as a source
    under the subquery's alias. The sources are merged into `scope.sources` at the end.
    """
    udtf = scope.expression

    if isinstance(udtf, exp.Unnest):
        candidates = udtf.expressions
    elif isinstance(udtf, exp.Lateral):
        candidates = [udtf.this]
    else:
        candidates = []

    found_sources = {}

    for candidate in candidates:
        if not isinstance(candidate, exp.Subquery):
            continue

        branched = scope.branch(
            candidate,
            scope_type=ScopeType.SUBQUERY,
            outer_columns=candidate.alias_column_names,
        )

        last_scope = None
        for inner_scope in _traverse_scope(branched):
            yield inner_scope
            last_scope = inner_scope
            found_sources[_get_source_alias(candidate)] = inner_scope

        scope.subquery_scopes.append(last_scope)

    scope.sources.update(found_sources)
868
869
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    The node that starts a child scope is itself yielded, but the walk does not
    descend into it. The "joins", "laterals" and "pivots" args of such a Subquery
    or UDTF node are still visited, since they belong to the current scope.

    Args:
        expression (exp.Expression): root expression of the scope to walk.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: the visited nodes, starting with `expression` itself.
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: bool(crossed_scope_boundary or (prune and prune(n)))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue

        node_type = type(node)
        parent_type = type(node.parent)
        if (
            node_type is exp.CTE
            or (parent_type in (exp.From, exp.Join) and _is_derived_table(node))
            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if node_type is exp.Subquery or isinstance(node, exp.UDTF):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        # Forward `prune` so the caller's filter also applies to these
                        # nested walks (it used to be dropped here).
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
914
915
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root expression of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Build the `isinstance` argument once instead of once per visited node.
    types = tuple(ensure_collection(expression_types))

    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node
934
935
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root expression of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None if no node in the scope matches.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match

    return None
952
953
def _get_source_alias(expression):
    """
    Return the name under which `expression` is registered as a source.

    This is the expression's alias. When that is empty and the expression has a
    TableAlias with exactly one column, the name of that column is used instead.
    """
    name = expression.alias
    if name:
        return name

    table_alias = expression.args.get("alias")
    if isinstance(table_alias, exp.TableAlias) and len(table_alias.columns) == 1:
        return table_alias.columns[0].name

    return name
logger = <Logger sqlglot (WARNING)>
TRAVERSABLES = (<class 'sqlglot.expressions.Query'>, <class 'sqlglot.expressions.DDL'>, <class 'sqlglot.expressions.DML'>)
class ScopeType(enum.Enum):
class ScopeType(Enum):
    """Kind of a `Scope`, relative to its parent scope."""

    # The root scope; this is the default `scope_type` of `Scope.__init__`.
    ROOT = auto()
    # A subquery.
    SUBQUERY = auto()
    # A derived table.
    DERIVED_TABLE = auto()
    # A common table expression.
    CTE = auto()
    # A union.
    UNION = auto()
    # A UDTF (User Defined Table Function).
    UDTF = auto()

An enumeration.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
class Scope:
 28class Scope:
 29    """
 30    Selection scope.
 31
 32    Attributes:
 33        expression (exp.Select|exp.SetOperation): Root expression of this scope
 34        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 35            a Table expression or another Scope instance. For example:
 36                SELECT * FROM x                     {"x": Table(this="x")}
 37                SELECT * FROM x AS y                {"y": Table(this="x")}
 38                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 39        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 40            For example:
 41                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 42            The LATERAL VIEW EXPLODE gets x as a source.
 43        cte_sources (dict[str, Scope]): Sources from CTES
 44        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
 45            defines a column list for the alias of this scope, this is that list of columns.
 46            For example:
 47                SELECT * FROM (SELECT ...) AS y(col1, col2)
 48            The inner query would have `["col1", "col2"]` for its `outer_columns`
 49        parent (Scope): Parent scope
 50        scope_type (ScopeType): Type of this scope, relative to it's parent
 51        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 52        cte_scopes (list[Scope]): List of all child scopes for CTEs
 53        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 54        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 55        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 56        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 57            a list of the left and right child scopes.
 58    """
 59
 60    def __init__(
 61        self,
 62        expression,
 63        sources=None,
 64        outer_columns=None,
 65        parent=None,
 66        scope_type=ScopeType.ROOT,
 67        lateral_sources=None,
 68        cte_sources=None,
 69        can_be_correlated=None,
 70    ):
 71        self.expression = expression
 72        self.sources = sources or {}
 73        self.lateral_sources = lateral_sources or {}
 74        self.cte_sources = cte_sources or {}
 75        self.sources.update(self.lateral_sources)
 76        self.sources.update(self.cte_sources)
 77        self.outer_columns = outer_columns or []
 78        self.parent = parent
 79        self.scope_type = scope_type
 80        self.subquery_scopes = []
 81        self.derived_table_scopes = []
 82        self.table_scopes = []
 83        self.cte_scopes = []
 84        self.union_scopes = []
 85        self.udtf_scopes = []
 86        self.can_be_correlated = can_be_correlated
 87        self.clear_cache()
 88
 89    def clear_cache(self):
 90        self._collected = False
 91        self._raw_columns = None
 92        self._table_columns = None
 93        self._stars = None
 94        self._derived_tables = None
 95        self._udtfs = None
 96        self._tables = None
 97        self._ctes = None
 98        self._subqueries = None
 99        self._selected_sources = None
100        self._columns = None
101        self._external_columns = None
102        self._local_columns = None
103        self._join_hints = None
104        self._pivots = None
105        self._references = None
106        self._semi_anti_join_tables = None
107        self._column_index = None
108
109    def branch(
110        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
111    ):
112        """Branch from the current scope to a new, inner scope"""
113        return Scope(
114            expression=expression.unnest(),
115            sources=sources.copy() if sources else None,
116            parent=self,
117            scope_type=scope_type,
118            cte_sources={**self.cte_sources, **(cte_sources or {})},
119            lateral_sources=lateral_sources.copy() if lateral_sources else None,
120            can_be_correlated=self.can_be_correlated
121            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
122            **kwargs,
123        )
124
125    def _collect(self):
126        self._tables = []
127        self._ctes = []
128        self._subqueries = []
129        self._derived_tables = []
130        self._udtfs = []
131        self._raw_columns = []
132        self._table_columns = []
133        self._stars = []
134        self._join_hints = []
135        self._semi_anti_join_tables = set()
136        self._column_index = set()
137
138        for node in self.walk(bfs=False):
139            if node is self.expression:
140                continue
141
142            node_type = type(node)
143
144            if node_type is exp.Dot and node.is_star:
145                self._stars.append(node)
146            elif node_type is exp.Column:
147                self._column_index.add(id(node))
148
149                if type(node.this) is exp.Star:
150                    self._stars.append(node)
151                else:
152                    self._raw_columns.append(node)
153            elif node_type is exp.Table and type(node.parent) is not exp.JoinHint:
154                parent = node.parent
155                if type(parent) is exp.Join and parent.is_semi_or_anti_join:
156                    self._semi_anti_join_tables.add(node.alias_or_name)
157
158                self._tables.append(node)
159            elif node_type is exp.JoinHint:
160                self._join_hints.append(node)
161            elif isinstance(node, exp.UDTF):
162                self._udtfs.append(node)
163            elif node_type is exp.CTE:
164                self._ctes.append(node)
165            elif _is_derived_table(node) and _is_from_or_join(node):
166                self._derived_tables.append(node)
167            elif isinstance(node, exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node):
168                self._subqueries.append(node)
169            elif node_type is exp.TableColumn:
170                self._table_columns.append(node)
171
172        self._collected = True
173
174    def _ensure_collected(self):
175        if not self._collected:
176            self._collect()
177
178    def walk(self, bfs=True, prune=None):
179        return walk_in_scope(self.expression, bfs=bfs, prune=None)
180
181    def find(self, *expression_types, bfs=True):
182        return find_in_scope(self.expression, expression_types, bfs=bfs)
183
184    def find_all(self, *expression_types, bfs=True):
185        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
186
187    def replace(self, old, new):
188        """
189        Replace `old` with `new`.
190
191        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
192
193        Args:
194            old (exp.Expression): old node
195            new (exp.Expression): new node
196        """
197        old.replace(new)
198        self.clear_cache()
199
200    @property
201    def tables(self):
202        """
203        List of tables in this scope.
204
205        Returns:
206            list[exp.Table]: tables
207        """
208        self._ensure_collected()
209        return self._tables
210
211    @property
212    def ctes(self):
213        """
214        List of CTEs in this scope.
215
216        Returns:
217            list[exp.CTE]: ctes
218        """
219        self._ensure_collected()
220        return self._ctes
221
222    @property
223    def derived_tables(self):
224        """
225        List of derived tables in this scope.
226
227        For example:
228            SELECT * FROM (SELECT ...) <- that's a derived table
229
230        Returns:
231            list[exp.Subquery]: derived tables
232        """
233        self._ensure_collected()
234        return self._derived_tables
235
236    @property
237    def udtfs(self):
238        """
239        List of "User Defined Tabular Functions" in this scope.
240
241        Returns:
242            list[exp.UDTF]: UDTFs
243        """
244        self._ensure_collected()
245        return self._udtfs
246
247    @property
248    def subqueries(self):
249        """
250        List of subqueries in this scope.
251
252        For example:
253            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
254
255        Returns:
256            list[exp.Select | exp.SetOperation]: subqueries
257        """
258        self._ensure_collected()
259        return self._subqueries
260
261    @property
262    def stars(self) -> t.List[exp.Column | exp.Dot]:
263        """
264        List of star expressions (columns or dots) in this scope.
265        """
266        self._ensure_collected()
267        return self._stars
268
269    @property
270    def column_index(self) -> t.Set[int]:
271        """
272        Set of column object IDs that belong to this scope's expression.
273        """
274        self._ensure_collected()
275        return self._column_index
276
277    @property
278    def columns(self):
279        """
280        List of columns in this scope.
281
282        Returns:
283            list[exp.Column]: Column instances in this scope, plus any
284                Columns that reference this scope from correlated subqueries.
285        """
286        if self._columns is None:
287            self._ensure_collected()
288            columns = self._raw_columns
289
290            external_columns = [
291                column
292                for scope in itertools.chain(
293                    self.subquery_scopes,
294                    self.udtf_scopes,
295                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
296                )
297                for column in scope.external_columns
298            ]
299
300            named_selects = set(self.expression.named_selects)
301
302            self._columns = []
303            for column in columns + external_columns:
304                ancestor = column.find_ancestor(
305                    exp.Select,
306                    exp.Qualify,
307                    exp.Order,
308                    exp.Having,
309                    exp.Hint,
310                    exp.Table,
311                    exp.Star,
312                    exp.Distinct,
313                )
314                if (
315                    not ancestor
316                    or column.table
317                    or isinstance(ancestor, exp.Select)
318                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
319                    or (
320                        isinstance(ancestor, (exp.Order, exp.Distinct))
321                        and (
322                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
323                            or not isinstance(ancestor.parent, exp.Select)
324                            or column.name not in named_selects
325                        )
326                    )
327                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
328                ):
329                    self._columns.append(column)
330
331        return self._columns
332
333    @property
334    def table_columns(self):
335        if self._table_columns is None:
336            self._ensure_collected()
337
338        return self._table_columns
339
340    @property
341    def selected_sources(self):
342        """
343        Mapping of nodes and sources that are actually selected from in this scope.
344
345        That is, all tables in a schema are selectable at any point. But a
346        table only becomes a selected source if it's included in a FROM or JOIN clause.
347
348        Returns:
349            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
350        """
351        if self._selected_sources is None:
352            result = {}
353
354            for name, node in self.references:
355                if name in self._semi_anti_join_tables:
356                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
357                    # selected source
358                    continue
359
360                if name in result:
361                    raise OptimizeError(f"Alias already used: {name}")
362                if name in self.sources:
363                    result[name] = (node, self.sources[name])
364
365            self._selected_sources = result
366        return self._selected_sources
367
368    @property
369    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
370        if self._references is None:
371            self._references = []
372
373            for table in self.tables:
374                self._references.append((table.alias_or_name, table))
375            for expression in itertools.chain(self.derived_tables, self.udtfs):
376                self._references.append(
377                    (
378                        _get_source_alias(expression),
379                        expression if expression.args.get("pivots") else expression.unnest(),
380                    )
381                )
382
383        return self._references
384
385    @property
386    def external_columns(self):
387        """
388        Columns that appear to reference sources in outer scopes.
389
390        Returns:
391            list[exp.Column]: Column instances that don't reference sources in the current scope.
392        """
393        if self._external_columns is None:
394            if isinstance(self.expression, exp.SetOperation):
395                left, right = self.union_scopes
396                self._external_columns = left.external_columns + right.external_columns
397            else:
398                self._external_columns = [
399                    c
400                    for c in self.columns
401                    if c.table not in self.sources and c.table not in self.semi_or_anti_join_tables
402                ]
403
404        return self._external_columns
405
406    @property
407    def local_columns(self):
408        """
409        Columns in this scope that are not external.
410
411        Returns:
412            list[exp.Column]: Column instances that reference sources in the current scope.
413        """
414        if self._local_columns is None:
415            external_columns = set(self.external_columns)
416            self._local_columns = [c for c in self.columns if c not in external_columns]
417
418        return self._local_columns
419
420    @property
421    def unqualified_columns(self):
422        """
423        Unqualified columns in the current scope.
424
425        Returns:
426             list[exp.Column]: Unqualified columns
427        """
428        return [c for c in self.columns if not c.table]
429
430    @property
431    def join_hints(self):
432        """
433        Hints that exist in the scope that reference tables
434
435        Returns:
436            list[exp.JoinHint]: Join hints that are referenced within the scope
437        """
438        if self._join_hints is None:
439            return []
440        return self._join_hints
441
442    @property
443    def pivots(self):
444        if not self._pivots:
445            self._pivots = [
446                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
447            ]
448
449        return self._pivots
450
451    @property
452    def semi_or_anti_join_tables(self):
453        return self._semi_anti_join_tables or set()
454
455    def source_columns(self, source_name):
456        """
457        Get all columns in the current scope for a particular source.
458
459        Args:
460            source_name (str): Name of the source
461        Returns:
462            list[exp.Column]: Column instances that reference `source_name`
463        """
464        return [column for column in self.columns if column.table == source_name]
465
466    @property
467    def is_subquery(self):
468        """Determine if this scope is a subquery"""
469        return self.scope_type == ScopeType.SUBQUERY
470
471    @property
472    def is_derived_table(self):
473        """Determine if this scope is a derived table"""
474        return self.scope_type == ScopeType.DERIVED_TABLE
475
476    @property
477    def is_union(self):
478        """Determine if this scope is a union"""
479        return self.scope_type == ScopeType.UNION
480
481    @property
482    def is_cte(self):
483        """Determine if this scope is a common table expression"""
484        return self.scope_type == ScopeType.CTE
485
486    @property
487    def is_root(self):
488        """Determine if this is the root scope"""
489        return self.scope_type == ScopeType.ROOT
490
491    @property
492    def is_udtf(self):
493        """Determine if this scope is a UDTF (User Defined Table Function)"""
494        return self.scope_type == ScopeType.UDTF
495
496    @property
497    def is_correlated_subquery(self):
498        """Determine if this scope is a correlated subquery"""
499        return bool(self.can_be_correlated and self.external_columns)
500
501    def rename_source(self, old_name, new_name):
502        """Rename a source in this scope"""
503        old_name = old_name or ""
504        if old_name in self.sources:
505            self.sources[new_name] = self.sources.pop(old_name)
506
507    def add_source(self, name, source):
508        """Add a source to this scope"""
509        self.sources[name] = source
510        self.clear_cache()
511
512    def remove_source(self, name):
513        """Remove a source from this scope"""
514        self.sources.pop(name, None)
515        self.clear_cache()
516
517    def __repr__(self):
518        return f"Scope<{self.expression.sql()}>"
519
520    def traverse(self):
521        """
522        Traverse the scope tree from this node.
523
524        Yields:
525            Scope: scope instances in depth-first-search post-order
526        """
527        stack = [self]
528        result = []
529        while stack:
530            scope = stack.pop()
531            result.append(scope)
532            stack.extend(
533                itertools.chain(
534                    scope.cte_scopes,
535                    scope.union_scopes,
536                    scope.table_scopes,
537                    scope.subquery_scopes,
538                )
539            )
540
541        yield from reversed(result)
542
543    def ref_count(self):
544        """
545        Count the number of times each scope in this tree is referenced.
546
547        Returns:
548            dict[int, int]: Mapping of Scope instance ID to reference count
549        """
550        scope_ref_count = defaultdict(lambda: 0)
551
552        for scope in self.traverse():
553            for _, source in scope.selected_sources.values():
554                scope_ref_count[id(source)] += 1
555
556            for name in scope._semi_anti_join_tables:
557                # semi/anti join sources are not actually selected but we still need to
558                # increment their ref count to avoid them being optimized away
559                if name in scope.sources:
560                    scope_ref_count[id(scope.sources[name])] += 1
561
562        return scope_ref_count

Selection scope.

Attributes:
  • expression (exp.Select|exp.SetOperation): Root expression of this scope
  • sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
  • cte_sources (dict[str, Scope]): Sources from CTES
  • outer_columns (list[str]): If this is a derived table or CTE, and the outer query defines a column list for the alias of this scope, this is that list of columns. For example: SELECT * FROM (SELECT ...) AS y(col1, col2) The inner query would have ["col1", "col2"] for its outer_columns
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to its parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
  • udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
Scope( expression, sources=None, outer_columns=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None, cte_sources=None, can_be_correlated=None)
60    def __init__(
61        self,
62        expression,
63        sources=None,
64        outer_columns=None,
65        parent=None,
66        scope_type=ScopeType.ROOT,
67        lateral_sources=None,
68        cte_sources=None,
69        can_be_correlated=None,
70    ):
71        self.expression = expression
72        self.sources = sources or {}
73        self.lateral_sources = lateral_sources or {}
74        self.cte_sources = cte_sources or {}
75        self.sources.update(self.lateral_sources)
76        self.sources.update(self.cte_sources)
77        self.outer_columns = outer_columns or []
78        self.parent = parent
79        self.scope_type = scope_type
80        self.subquery_scopes = []
81        self.derived_table_scopes = []
82        self.table_scopes = []
83        self.cte_scopes = []
84        self.union_scopes = []
85        self.udtf_scopes = []
86        self.can_be_correlated = can_be_correlated
87        self.clear_cache()
expression
sources
lateral_sources
cte_sources
outer_columns
parent
scope_type
subquery_scopes
derived_table_scopes
table_scopes
cte_scopes
union_scopes
udtf_scopes
can_be_correlated
def clear_cache(self):
 89    def clear_cache(self):
 90        self._collected = False
 91        self._raw_columns = None
 92        self._table_columns = None
 93        self._stars = None
 94        self._derived_tables = None
 95        self._udtfs = None
 96        self._tables = None
 97        self._ctes = None
 98        self._subqueries = None
 99        self._selected_sources = None
100        self._columns = None
101        self._external_columns = None
102        self._local_columns = None
103        self._join_hints = None
104        self._pivots = None
105        self._references = None
106        self._semi_anti_join_tables = None
107        self._column_index = None
def branch( self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs):
109    def branch(
110        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
111    ):
112        """Branch from the current scope to a new, inner scope"""
113        return Scope(
114            expression=expression.unnest(),
115            sources=sources.copy() if sources else None,
116            parent=self,
117            scope_type=scope_type,
118            cte_sources={**self.cte_sources, **(cte_sources or {})},
119            lateral_sources=lateral_sources.copy() if lateral_sources else None,
120            can_be_correlated=self.can_be_correlated
121            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
122            **kwargs,
123        )

Branch from the current scope to a new, inner scope

def walk(self, bfs=True, prune=None):
178    def walk(self, bfs=True, prune=None):
179        return walk_in_scope(self.expression, bfs=bfs, prune=None)
def find(self, *expression_types, bfs=True):
181    def find(self, *expression_types, bfs=True):
182        return find_in_scope(self.expression, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
184    def find_all(self, *expression_types, bfs=True):
185        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
def replace(self, old, new):
187    def replace(self, old, new):
188        """
189        Replace `old` with `new`.
190
191        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
192
193        Args:
194            old (exp.Expression): old node
195            new (exp.Expression): new node
196        """
197        old.replace(new)
198        self.clear_cache()

Replace old with new.

This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expression): old node
  • new (exp.Expression): new node
tables
200    @property
201    def tables(self):
202        """
203        List of tables in this scope.
204
205        Returns:
206            list[exp.Table]: tables
207        """
208        self._ensure_collected()
209        return self._tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes
211    @property
212    def ctes(self):
213        """
214        List of CTEs in this scope.
215
216        Returns:
217            list[exp.CTE]: ctes
218        """
219        self._ensure_collected()
220        return self._ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables
222    @property
223    def derived_tables(self):
224        """
225        List of derived tables in this scope.
226
227        For example:
228            SELECT * FROM (SELECT ...) <- that's a derived table
229
230        Returns:
231            list[exp.Subquery]: derived tables
232        """
233        self._ensure_collected()
234        return self._derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs
236    @property
237    def udtfs(self):
238        """
239        List of "User Defined Tabular Functions" in this scope.
240
241        Returns:
242            list[exp.UDTF]: UDTFs
243        """
244        self._ensure_collected()
245        return self._udtfs

List of "User Defined Tabular Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

subqueries
247    @property
248    def subqueries(self):
249        """
250        List of subqueries in this scope.
251
252        For example:
253            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
254
255        Returns:
256            list[exp.Select | exp.SetOperation]: subqueries
257        """
258        self._ensure_collected()
259        return self._subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Select | exp.SetOperation]: subqueries

261    @property
262    def stars(self) -> t.List[exp.Column | exp.Dot]:
263        """
264        List of star expressions (columns or dots) in this scope.
265        """
266        self._ensure_collected()
267        return self._stars

List of star expressions (columns or dots) in this scope.

column_index: Set[int]
269    @property
270    def column_index(self) -> t.Set[int]:
271        """
272        Set of column object IDs that belong to this scope's expression.
273        """
274        self._ensure_collected()
275        return self._column_index

Set of column object IDs that belong to this scope's expression.

columns
277    @property
278    def columns(self):
279        """
280        List of columns in this scope.
281
282        Returns:
283            list[exp.Column]: Column instances in this scope, plus any
284                Columns that reference this scope from correlated subqueries.
285        """
286        if self._columns is None:
287            self._ensure_collected()
288            columns = self._raw_columns
289
290            external_columns = [
291                column
292                for scope in itertools.chain(
293                    self.subquery_scopes,
294                    self.udtf_scopes,
295                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
296                )
297                for column in scope.external_columns
298            ]
299
300            named_selects = set(self.expression.named_selects)
301
302            self._columns = []
303            for column in columns + external_columns:
304                ancestor = column.find_ancestor(
305                    exp.Select,
306                    exp.Qualify,
307                    exp.Order,
308                    exp.Having,
309                    exp.Hint,
310                    exp.Table,
311                    exp.Star,
312                    exp.Distinct,
313                )
314                if (
315                    not ancestor
316                    or column.table
317                    or isinstance(ancestor, exp.Select)
318                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
319                    or (
320                        isinstance(ancestor, (exp.Order, exp.Distinct))
321                        and (
322                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
323                            or not isinstance(ancestor.parent, exp.Select)
324                            or column.name not in named_selects
325                        )
326                    )
327                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
328                ):
329                    self._columns.append(column)
330
331        return self._columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

table_columns
333    @property
334    def table_columns(self):
335        if self._table_columns is None:
336            self._ensure_collected()
337
338        return self._table_columns
selected_sources
340    @property
341    def selected_sources(self):
342        """
343        Mapping of nodes and sources that are actually selected from in this scope.
344
345        That is, all tables in a schema are selectable at any point. But a
346        table only becomes a selected source if it's included in a FROM or JOIN clause.
347
348        Returns:
349            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
350        """
351        if self._selected_sources is None:
352            result = {}
353
354            for name, node in self.references:
355                if name in self._semi_anti_join_tables:
356                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
357                    # selected source
358                    continue
359
360                if name in result:
361                    raise OptimizeError(f"Alias already used: {name}")
362                if name in self.sources:
363                    result[name] = (node, self.sources[name])
364
365            self._selected_sources = result
366        return self._selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

references: List[Tuple[str, sqlglot.expressions.Expression]]
368    @property
369    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
370        if self._references is None:
371            self._references = []
372
373            for table in self.tables:
374                self._references.append((table.alias_or_name, table))
375            for expression in itertools.chain(self.derived_tables, self.udtfs):
376                self._references.append(
377                    (
378                        _get_source_alias(expression),
379                        expression if expression.args.get("pivots") else expression.unnest(),
380                    )
381                )
382
383        return self._references
external_columns
385    @property
386    def external_columns(self):
387        """
388        Columns that appear to reference sources in outer scopes.
389
390        Returns:
391            list[exp.Column]: Column instances that don't reference sources in the current scope.
392        """
393        if self._external_columns is None:
394            if isinstance(self.expression, exp.SetOperation):
395                left, right = self.union_scopes
396                self._external_columns = left.external_columns + right.external_columns
397            else:
398                self._external_columns = [
399                    c
400                    for c in self.columns
401                    if c.table not in self.sources and c.table not in self.semi_or_anti_join_tables
402                ]
403
404        return self._external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

local_columns
406    @property
407    def local_columns(self):
408        """
409        Columns in this scope that are not external.
410
411        Returns:
412            list[exp.Column]: Column instances that reference sources in the current scope.
413        """
414        if self._local_columns is None:
415            external_columns = set(self.external_columns)
416            self._local_columns = [c for c in self.columns if c not in external_columns]
417
418        return self._local_columns

Columns in this scope that are not external.

Returns:

list[exp.Column]: Column instances that reference sources in the current scope.

unqualified_columns
420    @property
421    def unqualified_columns(self):
422        """
423        Unqualified columns in the current scope.
424
425        Returns:
426             list[exp.Column]: Unqualified columns
427        """
428        return [c for c in self.columns if not c.table]

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints
430    @property
431    def join_hints(self):
432        """
433        Hints that exist in the scope that reference tables
434
435        Returns:
436            list[exp.JoinHint]: Join hints that are referenced within the scope
437        """
438        if self._join_hints is None:
439            return []
440        return self._join_hints

Hints that exist in the scope that reference tables

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

pivots
442    @property
443    def pivots(self):
444        if not self._pivots:
445            self._pivots = [
446                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
447            ]
448
449        return self._pivots
semi_or_anti_join_tables
451    @property
452    def semi_or_anti_join_tables(self):
453        return self._semi_anti_join_tables or set()
def source_columns(self, source_name):
455    def source_columns(self, source_name):
456        """
457        Get all columns in the current scope for a particular source.
458
459        Args:
460            source_name (str): Name of the source
461        Returns:
462            list[exp.Column]: Column instances that reference `source_name`
463        """
464        return [column for column in self.columns if column.table == source_name]

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery
466    @property
467    def is_subquery(self):
468        """Determine if this scope is a subquery"""
469        return self.scope_type == ScopeType.SUBQUERY

Determine if this scope is a subquery

is_derived_table
471    @property
472    def is_derived_table(self):
473        """Determine if this scope is a derived table"""
474        return self.scope_type == ScopeType.DERIVED_TABLE

Determine if this scope is a derived table

is_union
476    @property
477    def is_union(self):
478        """Determine if this scope is a union"""
479        return self.scope_type == ScopeType.UNION

Determine if this scope is a union

is_cte
481    @property
482    def is_cte(self):
483        """Determine if this scope is a common table expression"""
484        return self.scope_type == ScopeType.CTE

Determine if this scope is a common table expression

is_root
486    @property
487    def is_root(self):
488        """Determine if this is the root scope"""
489        return self.scope_type == ScopeType.ROOT

Determine if this is the root scope

is_udtf
491    @property
492    def is_udtf(self):
493        """Determine if this scope is a UDTF (User Defined Table Function)"""
494        return self.scope_type == ScopeType.UDTF

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery
496    @property
497    def is_correlated_subquery(self):
498        """Determine if this scope is a correlated subquery"""
499        return bool(self.can_be_correlated and self.external_columns)

Determine if this scope is a correlated subquery

def rename_source(self, old_name, new_name):
501    def rename_source(self, old_name, new_name):
502        """Rename a source in this scope"""
503        old_name = old_name or ""
504        if old_name in self.sources:
505            self.sources[new_name] = self.sources.pop(old_name)

Rename a source in this scope

def add_source(self, name, source):
507    def add_source(self, name, source):
508        """Add a source to this scope"""
509        self.sources[name] = source
510        self.clear_cache()

Add a source to this scope

def remove_source(self, name):
512    def remove_source(self, name):
513        """Remove a source from this scope"""
514        self.sources.pop(name, None)
515        self.clear_cache()

Remove a source from this scope

def traverse(self):
520    def traverse(self):
521        """
522        Traverse the scope tree from this node.
523
524        Yields:
525            Scope: scope instances in depth-first-search post-order
526        """
527        stack = [self]
528        result = []
529        while stack:
530            scope = stack.pop()
531            result.append(scope)
532            stack.extend(
533                itertools.chain(
534                    scope.cte_scopes,
535                    scope.union_scopes,
536                    scope.table_scopes,
537                    scope.subquery_scopes,
538                )
539            )
540
541        yield from reversed(result)

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self):
543    def ref_count(self):
544        """
545        Count the number of times each scope in this tree is referenced.
546
547        Returns:
548            dict[int, int]: Mapping of Scope instance ID to reference count
549        """
550        scope_ref_count = defaultdict(lambda: 0)
551
552        for scope in self.traverse():
553            for _, source in scope.selected_sources.values():
554                scope_ref_count[id(source)] += 1
555
556            for name in scope._semi_anti_join_tables:
557                # semi/anti join sources are not actually selected but we still need to
558                # increment their ref count to avoid them being optimized away
559                if name in scope.sources:
560                    scope_ref_count[id(scope.sources[name])] += 1
561
562        return scope_ref_count

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope( expression: sqlglot.expressions.Expression) -> List[Scope]:
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    Scopes carry information the expression tree alone does not, such as the source
    names visible inside a subquery, which is what query optimization needs. The
    scopes are returned as a list rather than a generator because a generator could
    expose scopes whose properties are still incomplete.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty when `expression` is not one of
        the TRAVERSABLES types
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression: Expression to traverse
Returns:

A list of the created scope instances

def build_scope( expression: sqlglot.expressions.Expression) -> Optional[Scope]:
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build the scope tree of an expression.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, taken as the last scope produced by `traverse_scope`
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)

Build a scope tree.

Arguments:
  • expression: Expression to build the scope tree for.
Returns:

The root scope

def walk_in_scope(expression, bfs=True, prune=None):
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): root of the tree to walk.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: each visited node
    """
    # This flag is how the loop body talks to the `prune` lambda handed to `walk`:
    # the lambda reads it every time it is called, so setting it to True after a
    # yield excludes a subtree from traversal.
    # NOTE(review): this relies on `walk` consulting `prune` for a node only after that
    # node has been yielded and the loop body below has run -- confirm against `walk`.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: bool(crossed_scope_boundary or (prune and prune(n)))
    ):
        # Reset the flag for the node that is about to be yielded.
        crossed_scope_boundary = False

        yield node

        # The starting expression is the scope being walked, never a child scope.
        if node is expression:
            continue

        node_type = type(node)
        parent_type = type(node.parent)
        if (
            node_type is exp.CTE
            or (parent_type in (exp.From, exp.Join) and _is_derived_table(node))
            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if node_type is exp.Subquery or isinstance(node, exp.UDTF):
                # The following args are not actually in the inner scope, so we should visit them
                # NOTE(review): `prune` is not forwarded to this recursive call, so a
                # caller-supplied prune does not apply under these args -- confirm intended.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs)

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.

Arguments:
  • expression (exp.Expression):
  • bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
  • prune ((node) -> bool): callable that receives a node and returns True if the generator should stop traversing this branch of the tree.
Yields:

exp.Expression: each visited node

def find_all_in_scope(expression, expression_types, bfs=True):
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the accepted types once; they do not change while walking.
    wanted = tuple(ensure_collection(expression_types))

    # A separate loop name keeps the `expression` argument from being rebound.
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted):
            yield node

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:

exp.Expression: nodes

def find_in_scope(expression, expression_types, bfs=True):
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None if no node matching
        the criteria was found.
    """
    matches = find_all_in_scope(expression, expression_types, bfs=bfs)
    for node in matches:
        return node
    return None

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:

exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.