Edit on GitHub

sqlglot.optimizer.scope

  1from __future__ import annotations
  2
  3import itertools
  4import logging
  5import typing as t
  6from collections import defaultdict
  7from enum import Enum, auto
  8
  9from sqlglot import exp
 10from sqlglot.errors import OptimizeError
 11from sqlglot.helper import ensure_collection, find_new_name, seq_get
 12
 13logger = logging.getLogger("sqlglot")
 14
 15TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)
 16
 17
 18class ScopeType(Enum):
 19    ROOT = auto()
 20    SUBQUERY = auto()
 21    DERIVED_TABLE = auto()
 22    CTE = auto()
 23    UNION = auto()
 24    UDTF = auto()
 25
 26
 27class Scope:
 28    """
 29    Selection scope.
 30
 31    Attributes:
 32        expression (exp.Select|exp.SetOperation): Root expression of this scope
 33        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 34            a Table expression or another Scope instance. For example:
 35                SELECT * FROM x                     {"x": Table(this="x")}
 36                SELECT * FROM x AS y                {"y": Table(this="x")}
 37                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 38        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 39            For example:
 40                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 41            The LATERAL VIEW EXPLODE gets x as a source.
 42        cte_sources (dict[str, Scope]): Sources from CTES
 43        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
 44            defines a column list for the alias of this scope, this is that list of columns.
 45            For example:
 46                SELECT * FROM (SELECT ...) AS y(col1, col2)
 47            The inner query would have `["col1", "col2"]` for its `outer_columns`
 48        parent (Scope): Parent scope
 49        scope_type (ScopeType): Type of this scope, relative to its parent
 50        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 51        cte_scopes (list[Scope]): List of all child scopes for CTEs
 52        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 53        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 54        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 55        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 56            a list of the left and right child scopes.
 57    """
 58
 59    def __init__(
 60        self,
 61        expression,
 62        sources=None,
 63        outer_columns=None,
 64        parent=None,
 65        scope_type=ScopeType.ROOT,
 66        lateral_sources=None,
 67        cte_sources=None,
 68        can_be_correlated=None,
 69    ):
 70        self.expression = expression
 71        self.sources = sources or {}
 72        self.lateral_sources = lateral_sources or {}
 73        self.cte_sources = cte_sources or {}
 74        self.sources.update(self.lateral_sources)
 75        self.sources.update(self.cte_sources)
 76        self.outer_columns = outer_columns or []
 77        self.parent = parent
 78        self.scope_type = scope_type
 79        self.subquery_scopes = []
 80        self.derived_table_scopes = []
 81        self.table_scopes = []
 82        self.cte_scopes = []
 83        self.union_scopes = []
 84        self.udtf_scopes = []
 85        self.can_be_correlated = can_be_correlated
 86        self.clear_cache()
 87
 88    def clear_cache(self):
 89        self._collected = False
 90        self._raw_columns = None
 91        self._table_columns = None
 92        self._stars = None
 93        self._derived_tables = None
 94        self._udtfs = None
 95        self._tables = None
 96        self._ctes = None
 97        self._subqueries = None
 98        self._selected_sources = None
 99        self._columns = None
100        self._external_columns = None
101        self._local_columns = None
102        self._join_hints = None
103        self._pivots = None
104        self._references = None
105        self._semi_anti_join_tables = None
106        self._column_index = None
107
108    def branch(
109        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
110    ):
111        """Branch from the current scope to a new, inner scope"""
112        return Scope(
113            expression=expression.unnest(),
114            sources=sources.copy() if sources else None,
115            parent=self,
116            scope_type=scope_type,
117            cte_sources={**self.cte_sources, **(cte_sources or {})},
118            lateral_sources=lateral_sources.copy() if lateral_sources else None,
119            can_be_correlated=self.can_be_correlated
120            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
121            **kwargs,
122        )
123
124    def _collect(self):
125        self._tables = []
126        self._ctes = []
127        self._subqueries = []
128        self._derived_tables = []
129        self._udtfs = []
130        self._raw_columns = []
131        self._table_columns = []
132        self._stars = []
133        self._join_hints = []
134        self._semi_anti_join_tables = set()
135        self._column_index = set()
136
137        for node in self.walk(bfs=False):
138            if node is self.expression:
139                continue
140
141            if isinstance(node, exp.Dot) and node.is_star:
142                self._stars.append(node)
143            elif isinstance(node, exp.Column) and not isinstance(node, exp.Pseudocolumn):
144                self._column_index.add(id(node))
145
146                if isinstance(node.this, exp.Star):
147                    self._stars.append(node)
148                else:
149                    self._raw_columns.append(node)
150            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
151                parent = node.parent
152                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
153                    self._semi_anti_join_tables.add(node.alias_or_name)
154
155                self._tables.append(node)
156            elif isinstance(node, exp.JoinHint):
157                self._join_hints.append(node)
158            elif isinstance(node, exp.UDTF):
159                self._udtfs.append(node)
160            elif isinstance(node, exp.CTE):
161                self._ctes.append(node)
162            elif _is_derived_table(node) and _is_from_or_join(node):
163                self._derived_tables.append(node)
164            elif isinstance(node, exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node):
165                self._subqueries.append(node)
166            elif isinstance(node, exp.TableColumn):
167                self._table_columns.append(node)
168
169        self._collected = True
170
171    def _ensure_collected(self):
172        if not self._collected:
173            self._collect()
174
175    def walk(self, bfs=True, prune=None):
176        return walk_in_scope(self.expression, bfs=bfs, prune=None)
177
178    def find(self, *expression_types, bfs=True):
179        return find_in_scope(self.expression, expression_types, bfs=bfs)
180
181    def find_all(self, *expression_types, bfs=True):
182        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
183
184    def replace(self, old, new):
185        """
186        Replace `old` with `new`.
187
188        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
189
190        Args:
191            old (exp.Expression): old node
192            new (exp.Expression): new node
193        """
194        old.replace(new)
195        self.clear_cache()
196
197    @property
198    def tables(self):
199        """
200        List of tables in this scope.
201
202        Returns:
203            list[exp.Table]: tables
204        """
205        self._ensure_collected()
206        return self._tables
207
208    @property
209    def ctes(self):
210        """
211        List of CTEs in this scope.
212
213        Returns:
214            list[exp.CTE]: ctes
215        """
216        self._ensure_collected()
217        return self._ctes
218
219    @property
220    def derived_tables(self):
221        """
222        List of derived tables in this scope.
223
224        For example:
225            SELECT * FROM (SELECT ...) <- that's a derived table
226
227        Returns:
228            list[exp.Subquery]: derived tables
229        """
230        self._ensure_collected()
231        return self._derived_tables
232
233    @property
234    def udtfs(self):
235        """
236        List of "User Defined Tabular Functions" in this scope.
237
238        Returns:
239            list[exp.UDTF]: UDTFs
240        """
241        self._ensure_collected()
242        return self._udtfs
243
244    @property
245    def subqueries(self):
246        """
247        List of subqueries in this scope.
248
249        For example:
250            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
251
252        Returns:
253            list[exp.Select | exp.SetOperation]: subqueries
254        """
255        self._ensure_collected()
256        return self._subqueries
257
258    @property
259    def stars(self) -> t.List[exp.Column | exp.Dot]:
260        """
261        List of star expressions (columns or dots) in this scope.
262        """
263        self._ensure_collected()
264        return self._stars
265
266    @property
267    def column_index(self) -> t.Set[int]:
268        """
269        Set of column object IDs that belong to this scope's expression.
270        """
271        self._ensure_collected()
272        return self._column_index
273
274    @property
275    def columns(self):
276        """
277        List of columns in this scope.
278
279        Returns:
280            list[exp.Column]: Column instances in this scope, plus any
281                Columns that reference this scope from correlated subqueries.
282        """
283        if self._columns is None:
284            self._ensure_collected()
285            columns = self._raw_columns
286
287            external_columns = [
288                column
289                for scope in itertools.chain(
290                    self.subquery_scopes,
291                    self.udtf_scopes,
292                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
293                )
294                for column in scope.external_columns
295            ]
296
297            named_selects = set(self.expression.named_selects)
298
299            self._columns = []
300            for column in columns + external_columns:
301                ancestor = column.find_ancestor(
302                    exp.Select,
303                    exp.Qualify,
304                    exp.Order,
305                    exp.Having,
306                    exp.Hint,
307                    exp.Table,
308                    exp.Star,
309                    exp.Distinct,
310                )
311                if (
312                    not ancestor
313                    or column.table
314                    or isinstance(ancestor, exp.Select)
315                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
316                    or (
317                        isinstance(ancestor, (exp.Order, exp.Distinct))
318                        and (
319                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
320                            or not isinstance(ancestor.parent, exp.Select)
321                            or column.name not in named_selects
322                        )
323                    )
324                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
325                ):
326                    self._columns.append(column)
327
328        return self._columns
329
330    @property
331    def table_columns(self):
332        if self._table_columns is None:
333            self._ensure_collected()
334
335        return self._table_columns
336
337    @property
338    def selected_sources(self):
339        """
340        Mapping of nodes and sources that are actually selected from in this scope.
341
342        That is, all tables in a schema are selectable at any point. But a
343        table only becomes a selected source if it's included in a FROM or JOIN clause.
344
345        Returns:
346            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
347        """
348        if self._selected_sources is None:
349            result = {}
350
351            for name, node in self.references:
352                if name in self._semi_anti_join_tables:
353                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
354                    # selected source
355                    continue
356
357                if name in result:
358                    raise OptimizeError(f"Alias already used: {name}")
359                if name in self.sources:
360                    result[name] = (node, self.sources[name])
361
362            self._selected_sources = result
363        return self._selected_sources
364
365    @property
366    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
367        if self._references is None:
368            self._references = []
369
370            for table in self.tables:
371                self._references.append((table.alias_or_name, table))
372            for expression in itertools.chain(self.derived_tables, self.udtfs):
373                self._references.append(
374                    (
375                        _get_source_alias(expression),
376                        expression if expression.args.get("pivots") else expression.unnest(),
377                    )
378                )
379
380        return self._references
381
382    @property
383    def external_columns(self):
384        """
385        Columns that appear to reference sources in outer scopes.
386
387        Returns:
388            list[exp.Column]: Column instances that don't reference sources in the current scope.
389        """
390        if self._external_columns is None:
391            if isinstance(self.expression, exp.SetOperation):
392                left, right = self.union_scopes
393                self._external_columns = left.external_columns + right.external_columns
394            else:
395                self._external_columns = [
396                    c
397                    for c in self.columns
398                    if c.table not in self.sources and c.table not in self.semi_or_anti_join_tables
399                ]
400
401        return self._external_columns
402
403    @property
404    def local_columns(self):
405        """
406        Columns in this scope that are not external.
407
408        Returns:
409            list[exp.Column]: Column instances that reference sources in the current scope.
410        """
411        if self._local_columns is None:
412            external_columns = set(self.external_columns)
413            self._local_columns = [c for c in self.columns if c not in external_columns]
414
415        return self._local_columns
416
417    @property
418    def unqualified_columns(self):
419        """
420        Unqualified columns in the current scope.
421
422        Returns:
423             list[exp.Column]: Unqualified columns
424        """
425        return [c for c in self.columns if not c.table]
426
427    @property
428    def join_hints(self):
429        """
430        Hints that exist in the scope that reference tables
431
432        Returns:
433            list[exp.JoinHint]: Join hints that are referenced within the scope
434        """
435        if self._join_hints is None:
436            return []
437        return self._join_hints
438
439    @property
440    def pivots(self):
441        if not self._pivots:
442            self._pivots = [
443                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
444            ]
445
446        return self._pivots
447
448    @property
449    def semi_or_anti_join_tables(self):
450        return self._semi_anti_join_tables or set()
451
452    def source_columns(self, source_name):
453        """
454        Get all columns in the current scope for a particular source.
455
456        Args:
457            source_name (str): Name of the source
458        Returns:
459            list[exp.Column]: Column instances that reference `source_name`
460        """
461        return [column for column in self.columns if column.table == source_name]
462
463    @property
464    def is_subquery(self):
465        """Determine if this scope is a subquery"""
466        return self.scope_type == ScopeType.SUBQUERY
467
468    @property
469    def is_derived_table(self):
470        """Determine if this scope is a derived table"""
471        return self.scope_type == ScopeType.DERIVED_TABLE
472
473    @property
474    def is_union(self):
475        """Determine if this scope is a union"""
476        return self.scope_type == ScopeType.UNION
477
478    @property
479    def is_cte(self):
480        """Determine if this scope is a common table expression"""
481        return self.scope_type == ScopeType.CTE
482
483    @property
484    def is_root(self):
485        """Determine if this is the root scope"""
486        return self.scope_type == ScopeType.ROOT
487
488    @property
489    def is_udtf(self):
490        """Determine if this scope is a UDTF (User Defined Table Function)"""
491        return self.scope_type == ScopeType.UDTF
492
493    @property
494    def is_correlated_subquery(self):
495        """Determine if this scope is a correlated subquery"""
496        return bool(self.can_be_correlated and self.external_columns)
497
498    def rename_source(self, old_name, new_name):
499        """Rename a source in this scope"""
500        old_name = old_name or ""
501        if old_name in self.sources:
502            self.sources[new_name] = self.sources.pop(old_name)
503
504    def add_source(self, name, source):
505        """Add a source to this scope"""
506        self.sources[name] = source
507        self.clear_cache()
508
509    def remove_source(self, name):
510        """Remove a source from this scope"""
511        self.sources.pop(name, None)
512        self.clear_cache()
513
514    def __repr__(self):
515        return f"Scope<{self.expression.sql()}>"
516
517    def traverse(self):
518        """
519        Traverse the scope tree from this node.
520
521        Yields:
522            Scope: scope instances in depth-first-search post-order
523        """
524        stack = [self]
525        result = []
526        while stack:
527            scope = stack.pop()
528            result.append(scope)
529            stack.extend(
530                itertools.chain(
531                    scope.cte_scopes,
532                    scope.union_scopes,
533                    scope.table_scopes,
534                    scope.subquery_scopes,
535                )
536            )
537
538        yield from reversed(result)
539
540    def ref_count(self):
541        """
542        Count the number of times each scope in this tree is referenced.
543
544        Returns:
545            dict[int, int]: Mapping of Scope instance ID to reference count
546        """
547        scope_ref_count = defaultdict(lambda: 0)
548
549        for scope in self.traverse():
550            for _, source in scope.selected_sources.values():
551                scope_ref_count[id(source)] += 1
552
553            for name in scope._semi_anti_join_tables:
554                # semi/anti join sources are not actually selected but we still need to
555                # increment their ref count to avoid them being optimized away
556                if name in scope.sources:
557                    scope_ref_count[id(scope.sources[name])] += 1
558
559        return scope_ref_count
560
561
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty if `expression` is not
        one of the `TRAVERSABLES` types.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))
591
592
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, i.e. the last scope produced by `traverse_scope`,
        or None when no scope was produced.
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
604
605
606def _traverse_scope(scope):
607    expression = scope.expression
608
609    if isinstance(expression, exp.Select):
610        yield from _traverse_select(scope)
611    elif isinstance(expression, exp.SetOperation):
612        yield from _traverse_ctes(scope)
613        yield from _traverse_union(scope)
614        return
615    elif isinstance(expression, exp.Subquery):
616        if scope.is_root:
617            yield from _traverse_select(scope)
618        else:
619            yield from _traverse_subqueries(scope)
620    elif isinstance(expression, exp.Table):
621        yield from _traverse_tables(scope)
622    elif isinstance(expression, exp.UDTF):
623        yield from _traverse_udtfs(scope)
624    elif isinstance(expression, exp.DDL):
625        if isinstance(expression.expression, exp.Query):
626            yield from _traverse_ctes(scope)
627            yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources))
628        return
629    elif isinstance(expression, exp.DML):
630        yield from _traverse_ctes(scope)
631        for query in find_all_in_scope(expression, exp.Query):
632            # This check ensures we don't yield the CTE/nested queries twice
633            if not isinstance(query.parent, (exp.CTE, exp.Subquery)):
634                yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
635        return
636    else:
637        logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
638        return
639
640    yield scope
641
642
def _traverse_select(scope):
    """Yield the child scopes of `scope`: CTEs first, then table sources, then subqueries."""
    for traverse_part in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traverse_part(scope)
647
648
649def _traverse_union(scope):
650    prev_scope = None
651    union_scope_stack = [scope]
652    expression_stack = [scope.expression.right, scope.expression.left]
653
654    while expression_stack:
655        expression = expression_stack.pop()
656        union_scope = union_scope_stack[-1]
657
658        new_scope = union_scope.branch(
659            expression,
660            outer_columns=union_scope.outer_columns,
661            scope_type=ScopeType.UNION,
662        )
663
664        if isinstance(expression, exp.SetOperation):
665            yield from _traverse_ctes(new_scope)
666
667            union_scope_stack.append(new_scope)
668            expression_stack.extend([expression.right, expression.left])
669            continue
670
671        for scope in _traverse_scope(new_scope):
672            yield scope
673
674        if prev_scope:
675            union_scope_stack.pop()
676            union_scope.union_scopes = [prev_scope, scope]
677            prev_scope = union_scope
678
679            yield union_scope
680        else:
681            prev_scope = scope
682
683
684def _traverse_ctes(scope):
685    sources = {}
686
687    for cte in scope.ctes:
688        cte_name = cte.alias
689
690        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
691        # thus the recursive scope is the first section of the union.
692        with_ = scope.expression.args.get("with_")
693        if with_ and with_.recursive:
694            union = cte.this
695
696            if isinstance(union, exp.SetOperation):
697                sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)
698
699        child_scope = None
700
701        for child_scope in _traverse_scope(
702            scope.branch(
703                cte.this,
704                cte_sources=sources,
705                outer_columns=cte.alias_column_names,
706                scope_type=ScopeType.CTE,
707            )
708        ):
709            yield child_scope
710
711        # append the final child_scope yielded
712        if child_scope:
713            sources[cte_name] = child_scope
714            scope.cte_scopes.append(child_scope)
715
716    scope.sources.update(sources)
717    scope.cte_sources.update(sources)
718
719
def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule.

    Returns:
        True if `expression` is a Subquery that either has an alias or directly
        wraps an unwrapped query; False otherwise.
    """
    if not isinstance(expression, exp.Subquery):
        return False

    if expression.alias:
        return True

    return isinstance(expression.this, exp.UNWRAPPED_QUERIES)
729
730
def _is_from_or_join(expression: exp.Expression) -> bool:
    """
    Determine if `expression` is the FROM or JOIN clause of a SELECT statement.

    Any chain of enclosing Subquery nodes is skipped before the check, since
    subqueries can be arbitrarily nested.
    """
    ancestor = expression.parent

    while True:
        if not isinstance(ancestor, exp.Subquery):
            return isinstance(ancestor, (exp.From, exp.Join))
        ancestor = ancestor.parent
742
743
744def _traverse_tables(scope):
745    sources = {}
746
747    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
748    expressions = []
749    from_ = scope.expression.args.get("from_")
750    if from_:
751        expressions.append(from_.this)
752
753    for join in scope.expression.args.get("joins") or []:
754        expressions.append(join.this)
755
756    if isinstance(scope.expression, exp.Table):
757        expressions.append(scope.expression)
758
759    expressions.extend(scope.expression.args.get("laterals") or [])
760
761    for expression in expressions:
762        if isinstance(expression, exp.Final):
763            expression = expression.this
764        if isinstance(expression, exp.Table):
765            table_name = expression.name
766            source_name = expression.alias_or_name
767
768            if table_name in scope.sources and not expression.db:
769                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
770                # it is pivoted, because then we get back a new table and hence a new source.
771                pivots = expression.args.get("pivots")
772                if pivots:
773                    sources[pivots[0].alias] = expression
774                else:
775                    sources[source_name] = scope.sources[table_name]
776            elif source_name in sources:
777                sources[find_new_name(sources, table_name)] = expression
778            else:
779                sources[source_name] = expression
780
781            # Make sure to not include the joins twice
782            if expression is not scope.expression:
783                expressions.extend(join.this for join in expression.args.get("joins") or [])
784
785            continue
786
787        if not isinstance(expression, exp.DerivedTable):
788            continue
789
790        if isinstance(expression, exp.UDTF):
791            lateral_sources = sources
792            scope_type = ScopeType.UDTF
793            scopes = scope.udtf_scopes
794        elif _is_derived_table(expression):
795            lateral_sources = None
796            scope_type = ScopeType.DERIVED_TABLE
797            scopes = scope.derived_table_scopes
798            expressions.extend(join.this for join in expression.args.get("joins") or [])
799        else:
800            # Makes sure we check for possible sources in nested table constructs
801            expressions.append(expression.this)
802            expressions.extend(join.this for join in expression.args.get("joins") or [])
803            continue
804
805        child_scope = None
806
807        for child_scope in _traverse_scope(
808            scope.branch(
809                expression,
810                lateral_sources=lateral_sources,
811                outer_columns=expression.alias_column_names,
812                scope_type=scope_type,
813            )
814        ):
815            yield child_scope
816
817            # Tables without aliases will be set as ""
818            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
819            # Until then, this means that only a single, unaliased derived table is allowed (rather,
 820            # the latest one wins).
821            sources[_get_source_alias(expression)] = child_scope
822
823        # append the final child_scope yielded
824        if child_scope:
825            scopes.append(child_scope)
826            scope.table_scopes.append(child_scope)
827
828    scope.sources.update(sources)
829
830
831def _traverse_subqueries(scope):
832    for subquery in scope.subqueries:
833        top = None
834        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
835            yield child_scope
836            top = child_scope
837        scope.subquery_scopes.append(top)
838
839
def _traverse_udtfs(scope):
    """
    Traverse the subqueries nested directly inside a UDTF scope.

    For an Unnest the candidates are its `expressions`; for a Lateral it is `this`.
    Each Subquery candidate is traversed as a SUBQUERY scope; the scopes found are
    yielded, and the last one per subquery is added to `scope.subquery_scopes` and
    registered in `scope.sources` under the subquery's source alias.
    """
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if not isinstance(expression, exp.Subquery):
            continue

        top = None
        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                scope_type=ScopeType.SUBQUERY,
                outer_columns=expression.alias_column_names,
            )
        ):
            yield child_scope
            top = child_scope
            sources[_get_source_alias(expression)] = child_scope

        # Nothing is yielded for expressions that cannot be traversed; don't record None.
        if top is not None:
            scope.subquery_scopes.append(top)

    scope.sources.update(sources)
866
867
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): root of the tree to walk.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: each visited node. The root is always yielded; nodes that
        start a child scope are yielded too, but their subtrees are not entered.
    """
    # We'll use this variable to pass state into the walk generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue

        if (
            isinstance(node, exp.CTE)
            or (isinstance(node.parent, (exp.From, exp.Join)) and _is_derived_table(node))
            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        # Keep honoring the caller's prune inside these nested walks
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
910
911
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the requested types once instead of once per visited node.
    wanted = tuple(ensure_collection(expression_types))

    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted):
            yield node
930
931
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match

    return None
948
949
950def _get_source_alias(expression):
951    alias_arg = expression.args.get("alias")
952    alias_name = expression.alias
953
954    if not alias_name and isinstance(alias_arg, exp.TableAlias) and len(alias_arg.columns) == 1:
955        alias_name = alias_arg.columns[0].name
956
957    return alias_name
logger = <Logger sqlglot (WARNING)>
TRAVERSABLES = (<class 'sqlglot.expressions.Query'>, <class 'sqlglot.expressions.DDL'>, <class 'sqlglot.expressions.DML'>)
class ScopeType(enum.Enum):
class ScopeType(Enum):
    """Kind of a scope, relative to the scope it was branched from."""

    # Values are spelled out; they are the same 1..6 sequence `auto()` would assign.
    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6

An enumeration.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
class Scope:
 28class Scope:
 29    """
 30    Selection scope.
 31
 32    Attributes:
 33        expression (exp.Select|exp.SetOperation): Root expression of this scope
 34        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 35            a Table expression or another Scope instance. For example:
 36                SELECT * FROM x                     {"x": Table(this="x")}
 37                SELECT * FROM x AS y                {"y": Table(this="x")}
 38                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 39        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 40            For example:
 41                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 42            The LATERAL VIEW EXPLODE gets x as a source.
 43        cte_sources (dict[str, Scope]): Sources from CTES
 44        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
 45            defines a column list for the alias of this scope, this is that list of columns.
 46            For example:
 47                SELECT * FROM (SELECT ...) AS y(col1, col2)
 48            The inner query would have `["col1", "col2"]` for its `outer_columns`
 49        parent (Scope): Parent scope
 50        scope_type (ScopeType): Type of this scope, relative to it's parent
 51        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 52        cte_scopes (list[Scope]): List of all child scopes for CTEs
 53        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 54        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 55        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 56        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 57            a list of the left and right child scopes.
 58    """
 59
 60    def __init__(
 61        self,
 62        expression,
 63        sources=None,
 64        outer_columns=None,
 65        parent=None,
 66        scope_type=ScopeType.ROOT,
 67        lateral_sources=None,
 68        cte_sources=None,
 69        can_be_correlated=None,
 70    ):
 71        self.expression = expression
 72        self.sources = sources or {}
 73        self.lateral_sources = lateral_sources or {}
 74        self.cte_sources = cte_sources or {}
 75        self.sources.update(self.lateral_sources)
 76        self.sources.update(self.cte_sources)
 77        self.outer_columns = outer_columns or []
 78        self.parent = parent
 79        self.scope_type = scope_type
 80        self.subquery_scopes = []
 81        self.derived_table_scopes = []
 82        self.table_scopes = []
 83        self.cte_scopes = []
 84        self.union_scopes = []
 85        self.udtf_scopes = []
 86        self.can_be_correlated = can_be_correlated
 87        self.clear_cache()
 88
 89    def clear_cache(self):
 90        self._collected = False
 91        self._raw_columns = None
 92        self._table_columns = None
 93        self._stars = None
 94        self._derived_tables = None
 95        self._udtfs = None
 96        self._tables = None
 97        self._ctes = None
 98        self._subqueries = None
 99        self._selected_sources = None
100        self._columns = None
101        self._external_columns = None
102        self._local_columns = None
103        self._join_hints = None
104        self._pivots = None
105        self._references = None
106        self._semi_anti_join_tables = None
107        self._column_index = None
108
109    def branch(
110        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
111    ):
112        """Branch from the current scope to a new, inner scope"""
113        return Scope(
114            expression=expression.unnest(),
115            sources=sources.copy() if sources else None,
116            parent=self,
117            scope_type=scope_type,
118            cte_sources={**self.cte_sources, **(cte_sources or {})},
119            lateral_sources=lateral_sources.copy() if lateral_sources else None,
120            can_be_correlated=self.can_be_correlated
121            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
122            **kwargs,
123        )
124
125    def _collect(self):
126        self._tables = []
127        self._ctes = []
128        self._subqueries = []
129        self._derived_tables = []
130        self._udtfs = []
131        self._raw_columns = []
132        self._table_columns = []
133        self._stars = []
134        self._join_hints = []
135        self._semi_anti_join_tables = set()
136        self._column_index = set()
137
138        for node in self.walk(bfs=False):
139            if node is self.expression:
140                continue
141
142            if isinstance(node, exp.Dot) and node.is_star:
143                self._stars.append(node)
144            elif isinstance(node, exp.Column) and not isinstance(node, exp.Pseudocolumn):
145                self._column_index.add(id(node))
146
147                if isinstance(node.this, exp.Star):
148                    self._stars.append(node)
149                else:
150                    self._raw_columns.append(node)
151            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
152                parent = node.parent
153                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
154                    self._semi_anti_join_tables.add(node.alias_or_name)
155
156                self._tables.append(node)
157            elif isinstance(node, exp.JoinHint):
158                self._join_hints.append(node)
159            elif isinstance(node, exp.UDTF):
160                self._udtfs.append(node)
161            elif isinstance(node, exp.CTE):
162                self._ctes.append(node)
163            elif _is_derived_table(node) and _is_from_or_join(node):
164                self._derived_tables.append(node)
165            elif isinstance(node, exp.UNWRAPPED_QUERIES) and not _is_from_or_join(node):
166                self._subqueries.append(node)
167            elif isinstance(node, exp.TableColumn):
168                self._table_columns.append(node)
169
170        self._collected = True
171
172    def _ensure_collected(self):
173        if not self._collected:
174            self._collect()
175
176    def walk(self, bfs=True, prune=None):
177        return walk_in_scope(self.expression, bfs=bfs, prune=None)
178
179    def find(self, *expression_types, bfs=True):
180        return find_in_scope(self.expression, expression_types, bfs=bfs)
181
182    def find_all(self, *expression_types, bfs=True):
183        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
184
185    def replace(self, old, new):
186        """
187        Replace `old` with `new`.
188
189        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
190
191        Args:
192            old (exp.Expression): old node
193            new (exp.Expression): new node
194        """
195        old.replace(new)
196        self.clear_cache()
197
198    @property
199    def tables(self):
200        """
201        List of tables in this scope.
202
203        Returns:
204            list[exp.Table]: tables
205        """
206        self._ensure_collected()
207        return self._tables
208
209    @property
210    def ctes(self):
211        """
212        List of CTEs in this scope.
213
214        Returns:
215            list[exp.CTE]: ctes
216        """
217        self._ensure_collected()
218        return self._ctes
219
220    @property
221    def derived_tables(self):
222        """
223        List of derived tables in this scope.
224
225        For example:
226            SELECT * FROM (SELECT ...) <- that's a derived table
227
228        Returns:
229            list[exp.Subquery]: derived tables
230        """
231        self._ensure_collected()
232        return self._derived_tables
233
234    @property
235    def udtfs(self):
236        """
237        List of "User Defined Tabular Functions" in this scope.
238
239        Returns:
240            list[exp.UDTF]: UDTFs
241        """
242        self._ensure_collected()
243        return self._udtfs
244
245    @property
246    def subqueries(self):
247        """
248        List of subqueries in this scope.
249
250        For example:
251            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
252
253        Returns:
254            list[exp.Select | exp.SetOperation]: subqueries
255        """
256        self._ensure_collected()
257        return self._subqueries
258
259    @property
260    def stars(self) -> t.List[exp.Column | exp.Dot]:
261        """
262        List of star expressions (columns or dots) in this scope.
263        """
264        self._ensure_collected()
265        return self._stars
266
267    @property
268    def column_index(self) -> t.Set[int]:
269        """
270        Set of column object IDs that belong to this scope's expression.
271        """
272        self._ensure_collected()
273        return self._column_index
274
275    @property
276    def columns(self):
277        """
278        List of columns in this scope.
279
280        Returns:
281            list[exp.Column]: Column instances in this scope, plus any
282                Columns that reference this scope from correlated subqueries.
283        """
284        if self._columns is None:
285            self._ensure_collected()
286            columns = self._raw_columns
287
288            external_columns = [
289                column
290                for scope in itertools.chain(
291                    self.subquery_scopes,
292                    self.udtf_scopes,
293                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
294                )
295                for column in scope.external_columns
296            ]
297
298            named_selects = set(self.expression.named_selects)
299
300            self._columns = []
301            for column in columns + external_columns:
302                ancestor = column.find_ancestor(
303                    exp.Select,
304                    exp.Qualify,
305                    exp.Order,
306                    exp.Having,
307                    exp.Hint,
308                    exp.Table,
309                    exp.Star,
310                    exp.Distinct,
311                )
312                if (
313                    not ancestor
314                    or column.table
315                    or isinstance(ancestor, exp.Select)
316                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
317                    or (
318                        isinstance(ancestor, (exp.Order, exp.Distinct))
319                        and (
320                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
321                            or not isinstance(ancestor.parent, exp.Select)
322                            or column.name not in named_selects
323                        )
324                    )
325                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
326                ):
327                    self._columns.append(column)
328
329        return self._columns
330
331    @property
332    def table_columns(self):
333        if self._table_columns is None:
334            self._ensure_collected()
335
336        return self._table_columns
337
338    @property
339    def selected_sources(self):
340        """
341        Mapping of nodes and sources that are actually selected from in this scope.
342
343        That is, all tables in a schema are selectable at any point. But a
344        table only becomes a selected source if it's included in a FROM or JOIN clause.
345
346        Returns:
347            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
348        """
349        if self._selected_sources is None:
350            result = {}
351
352            for name, node in self.references:
353                if name in self._semi_anti_join_tables:
354                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
355                    # selected source
356                    continue
357
358                if name in result:
359                    raise OptimizeError(f"Alias already used: {name}")
360                if name in self.sources:
361                    result[name] = (node, self.sources[name])
362
363            self._selected_sources = result
364        return self._selected_sources
365
366    @property
367    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
368        if self._references is None:
369            self._references = []
370
371            for table in self.tables:
372                self._references.append((table.alias_or_name, table))
373            for expression in itertools.chain(self.derived_tables, self.udtfs):
374                self._references.append(
375                    (
376                        _get_source_alias(expression),
377                        expression if expression.args.get("pivots") else expression.unnest(),
378                    )
379                )
380
381        return self._references
382
383    @property
384    def external_columns(self):
385        """
386        Columns that appear to reference sources in outer scopes.
387
388        Returns:
389            list[exp.Column]: Column instances that don't reference sources in the current scope.
390        """
391        if self._external_columns is None:
392            if isinstance(self.expression, exp.SetOperation):
393                left, right = self.union_scopes
394                self._external_columns = left.external_columns + right.external_columns
395            else:
396                self._external_columns = [
397                    c
398                    for c in self.columns
399                    if c.table not in self.sources and c.table not in self.semi_or_anti_join_tables
400                ]
401
402        return self._external_columns
403
404    @property
405    def local_columns(self):
406        """
407        Columns in this scope that are not external.
408
409        Returns:
410            list[exp.Column]: Column instances that reference sources in the current scope.
411        """
412        if self._local_columns is None:
413            external_columns = set(self.external_columns)
414            self._local_columns = [c for c in self.columns if c not in external_columns]
415
416        return self._local_columns
417
418    @property
419    def unqualified_columns(self):
420        """
421        Unqualified columns in the current scope.
422
423        Returns:
424             list[exp.Column]: Unqualified columns
425        """
426        return [c for c in self.columns if not c.table]
427
428    @property
429    def join_hints(self):
430        """
431        Hints that exist in the scope that reference tables
432
433        Returns:
434            list[exp.JoinHint]: Join hints that are referenced within the scope
435        """
436        if self._join_hints is None:
437            return []
438        return self._join_hints
439
440    @property
441    def pivots(self):
442        if not self._pivots:
443            self._pivots = [
444                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
445            ]
446
447        return self._pivots
448
449    @property
450    def semi_or_anti_join_tables(self):
451        return self._semi_anti_join_tables or set()
452
453    def source_columns(self, source_name):
454        """
455        Get all columns in the current scope for a particular source.
456
457        Args:
458            source_name (str): Name of the source
459        Returns:
460            list[exp.Column]: Column instances that reference `source_name`
461        """
462        return [column for column in self.columns if column.table == source_name]
463
464    @property
465    def is_subquery(self):
466        """Determine if this scope is a subquery"""
467        return self.scope_type == ScopeType.SUBQUERY
468
469    @property
470    def is_derived_table(self):
471        """Determine if this scope is a derived table"""
472        return self.scope_type == ScopeType.DERIVED_TABLE
473
474    @property
475    def is_union(self):
476        """Determine if this scope is a union"""
477        return self.scope_type == ScopeType.UNION
478
479    @property
480    def is_cte(self):
481        """Determine if this scope is a common table expression"""
482        return self.scope_type == ScopeType.CTE
483
484    @property
485    def is_root(self):
486        """Determine if this is the root scope"""
487        return self.scope_type == ScopeType.ROOT
488
489    @property
490    def is_udtf(self):
491        """Determine if this scope is a UDTF (User Defined Table Function)"""
492        return self.scope_type == ScopeType.UDTF
493
494    @property
495    def is_correlated_subquery(self):
496        """Determine if this scope is a correlated subquery"""
497        return bool(self.can_be_correlated and self.external_columns)
498
499    def rename_source(self, old_name, new_name):
500        """Rename a source in this scope"""
501        old_name = old_name or ""
502        if old_name in self.sources:
503            self.sources[new_name] = self.sources.pop(old_name)
504
505    def add_source(self, name, source):
506        """Add a source to this scope"""
507        self.sources[name] = source
508        self.clear_cache()
509
510    def remove_source(self, name):
511        """Remove a source from this scope"""
512        self.sources.pop(name, None)
513        self.clear_cache()
514
515    def __repr__(self):
516        return f"Scope<{self.expression.sql()}>"
517
518    def traverse(self):
519        """
520        Traverse the scope tree from this node.
521
522        Yields:
523            Scope: scope instances in depth-first-search post-order
524        """
525        stack = [self]
526        result = []
527        while stack:
528            scope = stack.pop()
529            result.append(scope)
530            stack.extend(
531                itertools.chain(
532                    scope.cte_scopes,
533                    scope.union_scopes,
534                    scope.table_scopes,
535                    scope.subquery_scopes,
536                )
537            )
538
539        yield from reversed(result)
540
541    def ref_count(self):
542        """
543        Count the number of times each scope in this tree is referenced.
544
545        Returns:
546            dict[int, int]: Mapping of Scope instance ID to reference count
547        """
548        scope_ref_count = defaultdict(lambda: 0)
549
550        for scope in self.traverse():
551            for _, source in scope.selected_sources.values():
552                scope_ref_count[id(source)] += 1
553
554            for name in scope._semi_anti_join_tables:
555                # semi/anti join sources are not actually selected but we still need to
556                # increment their ref count to avoid them being optimized away
557                if name in scope.sources:
558                    scope_ref_count[id(scope.sources[name])] += 1
559
560        return scope_ref_count

Selection scope.

Attributes:
  • expression (exp.Select|exp.SetOperation): Root expression of this scope
  • sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
  • cte_sources (dict[str, Scope]): Sources from CTEs
  • outer_columns (list[str]): If this is a derived table or CTE, and the outer query defines a column list for the alias of this scope, this is that list of columns. For example: SELECT * FROM (SELECT ...) AS y(col1, col2) The inner query would have ["col1", "col2"] for its outer_columns
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to its parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
  • udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
Scope( expression, sources=None, outer_columns=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None, cte_sources=None, can_be_correlated=None)
60    def __init__(
61        self,
62        expression,
63        sources=None,
64        outer_columns=None,
65        parent=None,
66        scope_type=ScopeType.ROOT,
67        lateral_sources=None,
68        cte_sources=None,
69        can_be_correlated=None,
70    ):
71        self.expression = expression
72        self.sources = sources or {}
73        self.lateral_sources = lateral_sources or {}
74        self.cte_sources = cte_sources or {}
75        self.sources.update(self.lateral_sources)
76        self.sources.update(self.cte_sources)
77        self.outer_columns = outer_columns or []
78        self.parent = parent
79        self.scope_type = scope_type
80        self.subquery_scopes = []
81        self.derived_table_scopes = []
82        self.table_scopes = []
83        self.cte_scopes = []
84        self.union_scopes = []
85        self.udtf_scopes = []
86        self.can_be_correlated = can_be_correlated
87        self.clear_cache()
expression
sources
lateral_sources
cte_sources
outer_columns
parent
scope_type
subquery_scopes
derived_table_scopes
table_scopes
cte_scopes
union_scopes
udtf_scopes
can_be_correlated
def clear_cache(self):
 89    def clear_cache(self):
 90        self._collected = False
 91        self._raw_columns = None
 92        self._table_columns = None
 93        self._stars = None
 94        self._derived_tables = None
 95        self._udtfs = None
 96        self._tables = None
 97        self._ctes = None
 98        self._subqueries = None
 99        self._selected_sources = None
100        self._columns = None
101        self._external_columns = None
102        self._local_columns = None
103        self._join_hints = None
104        self._pivots = None
105        self._references = None
106        self._semi_anti_join_tables = None
107        self._column_index = None
def branch( self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs):
109    def branch(
110        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
111    ):
112        """Branch from the current scope to a new, inner scope"""
113        return Scope(
114            expression=expression.unnest(),
115            sources=sources.copy() if sources else None,
116            parent=self,
117            scope_type=scope_type,
118            cte_sources={**self.cte_sources, **(cte_sources or {})},
119            lateral_sources=lateral_sources.copy() if lateral_sources else None,
120            can_be_correlated=self.can_be_correlated
121            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
122            **kwargs,
123        )

Branch from the current scope to a new, inner scope

def walk(self, bfs=True, prune=None):
176    def walk(self, bfs=True, prune=None):
177        return walk_in_scope(self.expression, bfs=bfs, prune=None)
def find(self, *expression_types, bfs=True):
179    def find(self, *expression_types, bfs=True):
180        return find_in_scope(self.expression, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
182    def find_all(self, *expression_types, bfs=True):
183        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
def replace(self, old, new):
185    def replace(self, old, new):
186        """
187        Replace `old` with `new`.
188
189        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
190
191        Args:
192            old (exp.Expression): old node
193            new (exp.Expression): new node
194        """
195        old.replace(new)
196        self.clear_cache()

Replace old with new.

This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expression): old node
  • new (exp.Expression): new node
tables
198    @property
199    def tables(self):
200        """
201        List of tables in this scope.
202
203        Returns:
204            list[exp.Table]: tables
205        """
206        self._ensure_collected()
207        return self._tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes
209    @property
210    def ctes(self):
211        """
212        List of CTEs in this scope.
213
214        Returns:
215            list[exp.CTE]: ctes
216        """
217        self._ensure_collected()
218        return self._ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables
220    @property
221    def derived_tables(self):
222        """
223        List of derived tables in this scope.
224
225        For example:
226            SELECT * FROM (SELECT ...) <- that's a derived table
227
228        Returns:
229            list[exp.Subquery]: derived tables
230        """
231        self._ensure_collected()
232        return self._derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs
234    @property
235    def udtfs(self):
236        """
237        List of "User Defined Tabular Functions" in this scope.
238
239        Returns:
240            list[exp.UDTF]: UDTFs
241        """
242        self._ensure_collected()
243        return self._udtfs

List of "User Defined Tabular Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

subqueries
245    @property
246    def subqueries(self):
247        """
248        List of subqueries in this scope.
249
250        For example:
251            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
252
253        Returns:
254            list[exp.Select | exp.SetOperation]: subqueries
255        """
256        self._ensure_collected()
257        return self._subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Select | exp.SetOperation]: subqueries

259    @property
260    def stars(self) -> t.List[exp.Column | exp.Dot]:
261        """
262        List of star expressions (columns or dots) in this scope.
263        """
264        self._ensure_collected()
265        return self._stars

List of star expressions (columns or dots) in this scope.

column_index: Set[int]
267    @property
268    def column_index(self) -> t.Set[int]:
269        """
270        Set of column object IDs that belong to this scope's expression.
271        """
272        self._ensure_collected()
273        return self._column_index

Set of column object IDs that belong to this scope's expression.

columns
275    @property
276    def columns(self):
277        """
278        List of columns in this scope.
279
280        Returns:
281            list[exp.Column]: Column instances in this scope, plus any
282                Columns that reference this scope from correlated subqueries.
283        """
284        if self._columns is None:
285            self._ensure_collected()
286            columns = self._raw_columns
287
288            external_columns = [
289                column
290                for scope in itertools.chain(
291                    self.subquery_scopes,
292                    self.udtf_scopes,
293                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
294                )
295                for column in scope.external_columns
296            ]
297
298            named_selects = set(self.expression.named_selects)
299
300            self._columns = []
301            for column in columns + external_columns:
302                ancestor = column.find_ancestor(
303                    exp.Select,
304                    exp.Qualify,
305                    exp.Order,
306                    exp.Having,
307                    exp.Hint,
308                    exp.Table,
309                    exp.Star,
310                    exp.Distinct,
311                )
312                if (
313                    not ancestor
314                    or column.table
315                    or isinstance(ancestor, exp.Select)
316                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
317                    or (
318                        isinstance(ancestor, (exp.Order, exp.Distinct))
319                        and (
320                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
321                            or not isinstance(ancestor.parent, exp.Select)
322                            or column.name not in named_selects
323                        )
324                    )
325                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_")
326                ):
327                    self._columns.append(column)
328
329        return self._columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

table_columns
331    @property
332    def table_columns(self):
333        if self._table_columns is None:
334            self._ensure_collected()
335
336        return self._table_columns
selected_sources
338    @property
339    def selected_sources(self):
340        """
341        Mapping of nodes and sources that are actually selected from in this scope.
342
343        That is, all tables in a schema are selectable at any point. But a
344        table only becomes a selected source if it's included in a FROM or JOIN clause.
345
346        Returns:
347            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
348        """
349        if self._selected_sources is None:
350            result = {}
351
352            for name, node in self.references:
353                if name in self._semi_anti_join_tables:
354                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
355                    # selected source
356                    continue
357
358                if name in result:
359                    raise OptimizeError(f"Alias already used: {name}")
360                if name in self.sources:
361                    result[name] = (node, self.sources[name])
362
363            self._selected_sources = result
364        return self._selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

references: List[Tuple[str, sqlglot.expressions.Expression]]
366    @property
367    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
368        if self._references is None:
369            self._references = []
370
371            for table in self.tables:
372                self._references.append((table.alias_or_name, table))
373            for expression in itertools.chain(self.derived_tables, self.udtfs):
374                self._references.append(
375                    (
376                        _get_source_alias(expression),
377                        expression if expression.args.get("pivots") else expression.unnest(),
378                    )
379                )
380
381        return self._references
external_columns
383    @property
384    def external_columns(self):
385        """
386        Columns that appear to reference sources in outer scopes.
387
388        Returns:
389            list[exp.Column]: Column instances that don't reference sources in the current scope.
390        """
391        if self._external_columns is None:
392            if isinstance(self.expression, exp.SetOperation):
393                left, right = self.union_scopes
394                self._external_columns = left.external_columns + right.external_columns
395            else:
396                self._external_columns = [
397                    c
398                    for c in self.columns
399                    if c.table not in self.sources and c.table not in self.semi_or_anti_join_tables
400                ]
401
402        return self._external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

local_columns
404    @property
405    def local_columns(self):
406        """
407        Columns in this scope that are not external.
408
409        Returns:
410            list[exp.Column]: Column instances that reference sources in the current scope.
411        """
412        if self._local_columns is None:
413            external_columns = set(self.external_columns)
414            self._local_columns = [c for c in self.columns if c not in external_columns]
415
416        return self._local_columns

Columns in this scope that are not external.

Returns:

list[exp.Column]: Column instances that reference sources in the current scope.

unqualified_columns
418    @property
419    def unqualified_columns(self):
420        """
421        Unqualified columns in the current scope.
422
423        Returns:
424             list[exp.Column]: Unqualified columns
425        """
426        return [c for c in self.columns if not c.table]

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints
428    @property
429    def join_hints(self):
430        """
431        Hints that exist in the scope that reference tables
432
433        Returns:
434            list[exp.JoinHint]: Join hints that are referenced within the scope
435        """
436        if self._join_hints is None:
437            return []
438        return self._join_hints

Hints that exist in the scope that reference tables

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

pivots
440    @property
441    def pivots(self):
442        if not self._pivots:
443            self._pivots = [
444                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
445            ]
446
447        return self._pivots
semi_or_anti_join_tables
449    @property
450    def semi_or_anti_join_tables(self):
451        return self._semi_anti_join_tables or set()
def source_columns(self, source_name):
453    def source_columns(self, source_name):
454        """
455        Get all columns in the current scope for a particular source.
456
457        Args:
458            source_name (str): Name of the source
459        Returns:
460            list[exp.Column]: Column instances that reference `source_name`
461        """
462        return [column for column in self.columns if column.table == source_name]

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery
464    @property
465    def is_subquery(self):
466        """Determine if this scope is a subquery"""
467        return self.scope_type == ScopeType.SUBQUERY

Determine if this scope is a subquery

is_derived_table
469    @property
470    def is_derived_table(self):
471        """Determine if this scope is a derived table"""
472        return self.scope_type == ScopeType.DERIVED_TABLE

Determine if this scope is a derived table

is_union
474    @property
475    def is_union(self):
476        """Determine if this scope is a union"""
477        return self.scope_type == ScopeType.UNION

Determine if this scope is a union

is_cte
479    @property
480    def is_cte(self):
481        """Determine if this scope is a common table expression"""
482        return self.scope_type == ScopeType.CTE

Determine if this scope is a common table expression

is_root
484    @property
485    def is_root(self):
486        """Determine if this is the root scope"""
487        return self.scope_type == ScopeType.ROOT

Determine if this is the root scope

is_udtf
489    @property
490    def is_udtf(self):
491        """Determine if this scope is a UDTF (User Defined Table Function)"""
492        return self.scope_type == ScopeType.UDTF

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery
494    @property
495    def is_correlated_subquery(self):
496        """Determine if this scope is a correlated subquery"""
497        return bool(self.can_be_correlated and self.external_columns)

Determine if this scope is a correlated subquery

def rename_source(self, old_name, new_name):
499    def rename_source(self, old_name, new_name):
500        """Rename a source in this scope"""
501        old_name = old_name or ""
502        if old_name in self.sources:
503            self.sources[new_name] = self.sources.pop(old_name)

Rename a source in this scope

def add_source(self, name, source):
505    def add_source(self, name, source):
506        """Add a source to this scope"""
507        self.sources[name] = source
508        self.clear_cache()

Add a source to this scope

def remove_source(self, name):
510    def remove_source(self, name):
511        """Remove a source from this scope"""
512        self.sources.pop(name, None)
513        self.clear_cache()

Remove a source from this scope

def traverse(self):
518    def traverse(self):
519        """
520        Traverse the scope tree from this node.
521
522        Yields:
523            Scope: scope instances in depth-first-search post-order
524        """
525        stack = [self]
526        result = []
527        while stack:
528            scope = stack.pop()
529            result.append(scope)
530            stack.extend(
531                itertools.chain(
532                    scope.cte_scopes,
533                    scope.union_scopes,
534                    scope.table_scopes,
535                    scope.subquery_scopes,
536                )
537            )
538
539        yield from reversed(result)

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self):
541    def ref_count(self):
542        """
543        Count the number of times each scope in this tree is referenced.
544
545        Returns:
546            dict[int, int]: Mapping of Scope instance ID to reference count
547        """
548        scope_ref_count = defaultdict(lambda: 0)
549
550        for scope in self.traverse():
551            for _, source in scope.selected_sources.values():
552                scope_ref_count[id(source)] += 1
553
554            for name in scope._semi_anti_join_tables:
555                # semi/anti join sources are not actually selected but we still need to
556                # increment their ref count to avoid them being optimized away
557                if name in scope.sources:
558                    scope_ref_count[id(scope.sources[name])] += 1
559
560        return scope_ref_count

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope( expression: sqlglot.expressions.Expression) -> List[Scope]:
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    Scopes carry information the expression tree alone does not expose, such as the
    source names visible inside a subquery, which is what query optimizations need.
    The result is materialized as a list because handing out a generator could
    expose scopes whose properties are still incomplete.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty when `expression` is not an
        instance of one of the `TRAVERSABLES` types.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression: Expression to traverse
Returns:

A list of the created scope instances

def build_scope( expression: sqlglot.expressions.Expression) -> Optional[Scope]:
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    The scopes are created with `traverse_scope` and the element at index -1 of
    that list is handed back.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)

Build a scope tree.

Arguments:
  • expression: Expression to build the scope tree for.
Returns:

The root scope

def walk_in_scope(expression, bfs=True, prune=None):
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    A node that starts a child scope is still yielded itself; only the subtree below
    it is excluded. The root `expression` is never treated as a scope boundary.

    Args:
        expression (exp.Expression): root of the tree to walk.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): optional callable, invoked with a single node, that
            returns True if the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: each visited node
    """
    # We'll use this variable to pass state into the walk generator's prune callback.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        # Reset the flag so it only affects the node that set it.
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue

        # Nodes that open a child scope: CTEs, derived tables directly under FROM/JOIN,
        # queries directly under a UDTF, and unwrapped queries.
        if (
            isinstance(node, exp.CTE)
            or (isinstance(node.parent, (exp.From, exp.Join)) and _is_derived_table(node))
            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                # NOTE(review): the caller's `prune` is not forwarded to this nested walk --
                # confirm that this is intentional.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs)

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.

Arguments:
  • expression (exp.Expression):
  • bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
  • prune ((node) -> bool): optional callable, invoked with a single node, that returns True if the generator should stop traversing this branch of the tree.
Yields:

exp.Expression: each visited node

def find_all_in_scope(expression, expression_types, bfs=True):
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the type filter once instead of once per visited node; this also keeps
    # a one-shot iterable of types from being exhausted after the first node.
    types = tuple(ensure_collection(expression_types))

    # A distinct loop variable avoids shadowing the `expression` parameter.
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:

exp.Expression: nodes

def find_in_scope(expression, expression_types, bfs=True):
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node produced by `find_all_in_scope`, or
        None if no node matching the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        # The first hit ends the search.
        return match
    return None

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:

exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.