Edit on GitHub

sqlglot.optimizer.scope

  1from __future__ import annotations
  2
  3import itertools
  4import logging
  5import typing as t
  6from collections import defaultdict
  7from enum import Enum, auto
  8
  9from sqlglot import exp
 10from sqlglot.errors import OptimizeError
 11from sqlglot.helper import ensure_collection, find_new_name, seq_get
 12
# Logger shared under the "sqlglot" name; used in this module to warn when a scope's
# expression type cannot be traversed.
logger = logging.getLogger("sqlglot")

# Expression types accepted as a root by `traverse_scope`; any other root produces no scopes.
TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)
 16
 17
 18class ScopeType(Enum):
 19    ROOT = auto()
 20    SUBQUERY = auto()
 21    DERIVED_TABLE = auto()
 22    CTE = auto()
 23    UNION = auto()
 24    UDTF = auto()
 25
 26
 27class Scope:
 28    """
 29    Selection scope.
 30
 31    Attributes:
 32        expression (exp.Select|exp.SetOperation): Root expression of this scope
 33        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 34            a Table expression or another Scope instance. For example:
 35                SELECT * FROM x                     {"x": Table(this="x")}
 36                SELECT * FROM x AS y                {"y": Table(this="x")}
 37                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 38        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 39            For example:
 40                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 41            The LATERAL VIEW EXPLODE gets x as a source.
 42        cte_sources (dict[str, Scope]): Sources from CTES
 43        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
 44            defines a column list for the alias of this scope, this is that list of columns.
 45            For example:
 46                SELECT * FROM (SELECT ...) AS y(col1, col2)
 47            The inner query would have `["col1", "col2"]` for its `outer_columns`
 48        parent (Scope): Parent scope
 49        scope_type (ScopeType): Type of this scope, relative to it's parent
 50        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 51        cte_scopes (list[Scope]): List of all child scopes for CTEs
 52        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 53        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 54        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 55        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 56            a list of the left and right child scopes.
 57    """
 58
 59    def __init__(
 60        self,
 61        expression,
 62        sources=None,
 63        outer_columns=None,
 64        parent=None,
 65        scope_type=ScopeType.ROOT,
 66        lateral_sources=None,
 67        cte_sources=None,
 68        can_be_correlated=None,
 69    ):
 70        self.expression = expression
 71        self.sources = sources or {}
 72        self.lateral_sources = lateral_sources or {}
 73        self.cte_sources = cte_sources or {}
 74        self.sources.update(self.lateral_sources)
 75        self.sources.update(self.cte_sources)
 76        self.outer_columns = outer_columns or []
 77        self.parent = parent
 78        self.scope_type = scope_type
 79        self.subquery_scopes = []
 80        self.derived_table_scopes = []
 81        self.table_scopes = []
 82        self.cte_scopes = []
 83        self.union_scopes = []
 84        self.udtf_scopes = []
 85        self.can_be_correlated = can_be_correlated
 86        self.clear_cache()
 87
 88    def clear_cache(self):
 89        self._collected = False
 90        self._raw_columns = None
 91        self._table_columns = None
 92        self._stars = None
 93        self._derived_tables = None
 94        self._udtfs = None
 95        self._tables = None
 96        self._ctes = None
 97        self._subqueries = None
 98        self._selected_sources = None
 99        self._columns = None
100        self._external_columns = None
101        self._join_hints = None
102        self._pivots = None
103        self._references = None
104        self._semi_anti_join_tables = None
105
106    def branch(
107        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
108    ):
109        """Branch from the current scope to a new, inner scope"""
110        return Scope(
111            expression=expression.unnest(),
112            sources=sources.copy() if sources else None,
113            parent=self,
114            scope_type=scope_type,
115            cte_sources={**self.cte_sources, **(cte_sources or {})},
116            lateral_sources=lateral_sources.copy() if lateral_sources else None,
117            can_be_correlated=self.can_be_correlated
118            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
119            **kwargs,
120        )
121
122    def _collect(self):
123        self._tables = []
124        self._ctes = []
125        self._subqueries = []
126        self._derived_tables = []
127        self._udtfs = []
128        self._raw_columns = []
129        self._table_columns = []
130        self._stars = []
131        self._join_hints = []
132        self._semi_anti_join_tables = set()
133
134        for node in self.walk(bfs=False):
135            if node is self.expression:
136                continue
137
138            if isinstance(node, exp.Dot) and node.is_star:
139                self._stars.append(node)
140            elif isinstance(node, exp.Column):
141                if isinstance(node.this, exp.Star):
142                    self._stars.append(node)
143                else:
144                    self._raw_columns.append(node)
145            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
146                parent = node.parent
147                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
148                    self._semi_anti_join_tables.add(node.alias_or_name)
149
150                self._tables.append(node)
151            elif isinstance(node, exp.JoinHint):
152                self._join_hints.append(node)
153            elif isinstance(node, exp.UDTF):
154                self._udtfs.append(node)
155            elif isinstance(node, exp.CTE):
156                self._ctes.append(node)
157            elif _is_derived_table(node) and _is_from_or_join(node):
158                self._derived_tables.append(node)
159            elif isinstance(node, exp.UNWRAPPED_QUERIES):
160                self._subqueries.append(node)
161            elif isinstance(node, exp.TableColumn):
162                self._table_columns.append(node)
163
164        self._collected = True
165
166    def _ensure_collected(self):
167        if not self._collected:
168            self._collect()
169
170    def walk(self, bfs=True, prune=None):
171        return walk_in_scope(self.expression, bfs=bfs, prune=None)
172
173    def find(self, *expression_types, bfs=True):
174        return find_in_scope(self.expression, expression_types, bfs=bfs)
175
176    def find_all(self, *expression_types, bfs=True):
177        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
178
179    def replace(self, old, new):
180        """
181        Replace `old` with `new`.
182
183        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
184
185        Args:
186            old (exp.Expression): old node
187            new (exp.Expression): new node
188        """
189        old.replace(new)
190        self.clear_cache()
191
192    @property
193    def tables(self):
194        """
195        List of tables in this scope.
196
197        Returns:
198            list[exp.Table]: tables
199        """
200        self._ensure_collected()
201        return self._tables
202
203    @property
204    def ctes(self):
205        """
206        List of CTEs in this scope.
207
208        Returns:
209            list[exp.CTE]: ctes
210        """
211        self._ensure_collected()
212        return self._ctes
213
214    @property
215    def derived_tables(self):
216        """
217        List of derived tables in this scope.
218
219        For example:
220            SELECT * FROM (SELECT ...) <- that's a derived table
221
222        Returns:
223            list[exp.Subquery]: derived tables
224        """
225        self._ensure_collected()
226        return self._derived_tables
227
228    @property
229    def udtfs(self):
230        """
231        List of "User Defined Tabular Functions" in this scope.
232
233        Returns:
234            list[exp.UDTF]: UDTFs
235        """
236        self._ensure_collected()
237        return self._udtfs
238
239    @property
240    def subqueries(self):
241        """
242        List of subqueries in this scope.
243
244        For example:
245            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
246
247        Returns:
248            list[exp.Select | exp.SetOperation]: subqueries
249        """
250        self._ensure_collected()
251        return self._subqueries
252
253    @property
254    def stars(self) -> t.List[exp.Column | exp.Dot]:
255        """
256        List of star expressions (columns or dots) in this scope.
257        """
258        self._ensure_collected()
259        return self._stars
260
261    @property
262    def columns(self):
263        """
264        List of columns in this scope.
265
266        Returns:
267            list[exp.Column]: Column instances in this scope, plus any
268                Columns that reference this scope from correlated subqueries.
269        """
270        if self._columns is None:
271            self._ensure_collected()
272            columns = self._raw_columns
273
274            external_columns = [
275                column
276                for scope in itertools.chain(
277                    self.subquery_scopes,
278                    self.udtf_scopes,
279                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
280                )
281                for column in scope.external_columns
282            ]
283
284            named_selects = set(self.expression.named_selects)
285
286            self._columns = []
287            for column in columns + external_columns:
288                ancestor = column.find_ancestor(
289                    exp.Select,
290                    exp.Qualify,
291                    exp.Order,
292                    exp.Having,
293                    exp.Hint,
294                    exp.Table,
295                    exp.Star,
296                    exp.Distinct,
297                )
298                if (
299                    not ancestor
300                    or column.table
301                    or isinstance(ancestor, exp.Select)
302                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
303                    or (
304                        isinstance(ancestor, (exp.Order, exp.Distinct))
305                        and (
306                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
307                            or not isinstance(ancestor.parent, exp.Select)
308                            or column.name not in named_selects
309                        )
310                    )
311                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
312                ):
313                    self._columns.append(column)
314
315        return self._columns
316
317    @property
318    def table_columns(self):
319        if self._table_columns is None:
320            self._ensure_collected()
321
322        return self._table_columns
323
324    @property
325    def selected_sources(self):
326        """
327        Mapping of nodes and sources that are actually selected from in this scope.
328
329        That is, all tables in a schema are selectable at any point. But a
330        table only becomes a selected source if it's included in a FROM or JOIN clause.
331
332        Returns:
333            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
334        """
335        if self._selected_sources is None:
336            result = {}
337
338            for name, node in self.references:
339                if name in self._semi_anti_join_tables:
340                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
341                    # selected source
342                    continue
343
344                if name in result:
345                    raise OptimizeError(f"Alias already used: {name}")
346                if name in self.sources:
347                    result[name] = (node, self.sources[name])
348
349            self._selected_sources = result
350        return self._selected_sources
351
352    @property
353    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
354        if self._references is None:
355            self._references = []
356
357            for table in self.tables:
358                self._references.append((table.alias_or_name, table))
359            for expression in itertools.chain(self.derived_tables, self.udtfs):
360                self._references.append(
361                    (
362                        _get_source_alias(expression),
363                        expression if expression.args.get("pivots") else expression.unnest(),
364                    )
365                )
366
367        return self._references
368
369    @property
370    def external_columns(self):
371        """
372        Columns that appear to reference sources in outer scopes.
373
374        Returns:
375            list[exp.Column]: Column instances that don't reference
376                sources in the current scope.
377        """
378        if self._external_columns is None:
379            if isinstance(self.expression, exp.SetOperation):
380                left, right = self.union_scopes
381                self._external_columns = left.external_columns + right.external_columns
382            else:
383                self._external_columns = [
384                    c
385                    for c in self.columns
386                    if c.table not in self.selected_sources
387                    and c.table not in self.semi_or_anti_join_tables
388                ]
389
390        return self._external_columns
391
392    @property
393    def unqualified_columns(self):
394        """
395        Unqualified columns in the current scope.
396
397        Returns:
398             list[exp.Column]: Unqualified columns
399        """
400        return [c for c in self.columns if not c.table]
401
402    @property
403    def join_hints(self):
404        """
405        Hints that exist in the scope that reference tables
406
407        Returns:
408            list[exp.JoinHint]: Join hints that are referenced within the scope
409        """
410        if self._join_hints is None:
411            return []
412        return self._join_hints
413
414    @property
415    def pivots(self):
416        if not self._pivots:
417            self._pivots = [
418                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
419            ]
420
421        return self._pivots
422
423    @property
424    def semi_or_anti_join_tables(self):
425        return self._semi_anti_join_tables or set()
426
427    def source_columns(self, source_name):
428        """
429        Get all columns in the current scope for a particular source.
430
431        Args:
432            source_name (str): Name of the source
433        Returns:
434            list[exp.Column]: Column instances that reference `source_name`
435        """
436        return [column for column in self.columns if column.table == source_name]
437
438    @property
439    def is_subquery(self):
440        """Determine if this scope is a subquery"""
441        return self.scope_type == ScopeType.SUBQUERY
442
443    @property
444    def is_derived_table(self):
445        """Determine if this scope is a derived table"""
446        return self.scope_type == ScopeType.DERIVED_TABLE
447
448    @property
449    def is_union(self):
450        """Determine if this scope is a union"""
451        return self.scope_type == ScopeType.UNION
452
453    @property
454    def is_cte(self):
455        """Determine if this scope is a common table expression"""
456        return self.scope_type == ScopeType.CTE
457
458    @property
459    def is_root(self):
460        """Determine if this is the root scope"""
461        return self.scope_type == ScopeType.ROOT
462
463    @property
464    def is_udtf(self):
465        """Determine if this scope is a UDTF (User Defined Table Function)"""
466        return self.scope_type == ScopeType.UDTF
467
468    @property
469    def is_correlated_subquery(self):
470        """Determine if this scope is a correlated subquery"""
471        return bool(self.can_be_correlated and self.external_columns)
472
473    def rename_source(self, old_name, new_name):
474        """Rename a source in this scope"""
475        columns = self.sources.pop(old_name or "", [])
476        self.sources[new_name] = columns
477
478    def add_source(self, name, source):
479        """Add a source to this scope"""
480        self.sources[name] = source
481        self.clear_cache()
482
483    def remove_source(self, name):
484        """Remove a source from this scope"""
485        self.sources.pop(name, None)
486        self.clear_cache()
487
488    def __repr__(self):
489        return f"Scope<{self.expression.sql()}>"
490
491    def traverse(self):
492        """
493        Traverse the scope tree from this node.
494
495        Yields:
496            Scope: scope instances in depth-first-search post-order
497        """
498        stack = [self]
499        result = []
500        while stack:
501            scope = stack.pop()
502            result.append(scope)
503            stack.extend(
504                itertools.chain(
505                    scope.cte_scopes,
506                    scope.union_scopes,
507                    scope.table_scopes,
508                    scope.subquery_scopes,
509                )
510            )
511
512        yield from reversed(result)
513
514    def ref_count(self):
515        """
516        Count the number of times each scope in this tree is referenced.
517
518        Returns:
519            dict[int, int]: Mapping of Scope instance ID to reference count
520        """
521        scope_ref_count = defaultdict(lambda: 0)
522
523        for scope in self.traverse():
524            for _, source in scope.selected_sources.values():
525                scope_ref_count[id(source)] += 1
526
527            for name in scope._semi_anti_join_tables:
528                # semi/anti join sources are not actually selected but we still need to
529                # increment their ref count to avoid them being optimized away
530                if name in scope.sources:
531                    scope_ref_count[id(scope.sources[name])] += 1
532
533        return scope_ref_count
534
535
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Collect every scope found under `expression`.

    A "Scope" captures the context of a Select statement, such as the names of the
    sources it can reference, which the expression tree alone does not expose. The
    scopes are materialized into a list rather than streamed, because a scope's
    properties may still be incomplete while the traversal is running.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty when `expression` is not one of
        the `TRAVERSABLES` types.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))
565
566
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build the scope tree for `expression` and hand back its root.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, which is the last scope produced by `traverse_scope`.
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)
578
579
def _traverse_scope(scope):
    """
    Yield the scopes nested under `scope`, dispatching on the type of its expression.

    For selects, subqueries, tables and UDTFs, `scope` itself is yielded last. Set
    operations, DDL and DML statements delegate entirely to the helpers they call, and
    unsupported expression types only log a warning.
    """
    node = scope.expression

    if isinstance(node, exp.Select):
        yield from _traverse_select(scope)
        yield scope
        return

    if isinstance(node, exp.SetOperation):
        yield from _traverse_ctes(scope)
        yield from _traverse_union(scope)
        return

    if isinstance(node, exp.Subquery):
        handler = _traverse_select if scope.is_root else _traverse_subqueries
        yield from handler(scope)
        yield scope
        return

    if isinstance(node, exp.Table):
        yield from _traverse_tables(scope)
        yield scope
        return

    if isinstance(node, exp.UDTF):
        yield from _traverse_udtfs(scope)
        yield scope
        return

    # DDL is tested before DML, as in the original dispatch order.
    if isinstance(node, exp.DDL):
        if isinstance(node.expression, exp.Query):
            yield from _traverse_ctes(scope)
            yield from _traverse_scope(Scope(node.expression, cte_sources=scope.cte_sources))
        return

    if isinstance(node, exp.DML):
        yield from _traverse_ctes(scope)
        for query in find_all_in_scope(node, exp.Query):
            # Queries owned by a CTE or Subquery are skipped so they are not yielded twice
            if isinstance(query.parent, (exp.CTE, exp.Subquery)):
                continue
            yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
        return

    logger.warning("Cannot traverse scope %s with type '%s'", node, type(node))
615
616
def _traverse_select(scope):
    """Yield the child scopes of a select: CTEs first, then tables, then subqueries."""
    for traverse_step in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traverse_step(scope)
621
622
def _traverse_union(scope):
    """
    Yield the scopes nested under a set operation.

    Operands are processed with two explicit stacks rather than recursion: one holding
    the operand expressions still to visit, the other holding the scopes of the set
    operations whose operands are not complete yet. The scope of a set operation is
    yielded after its second operand has been traversed, with `union_scopes` set to
    the pair of scopes recorded for its operands.
    """
    # Scope recorded for the most recently completed operand (or nested set operation).
    prev_scope = None
    union_scope_stack = [scope]
    # The left operand is pushed last so that it is popped, and traversed, first.
    expression_stack = [scope.expression.right, scope.expression.left]

    while expression_stack:
        expression = expression_stack.pop()
        union_scope = union_scope_stack[-1]

        new_scope = union_scope.branch(
            expression,
            outer_columns=union_scope.outer_columns,
            scope_type=ScopeType.UNION,
        )

        if isinstance(expression, exp.SetOperation):
            yield from _traverse_ctes(new_scope)

            # Descend into the nested set operation: its operands are handled next and
            # are branched from `new_scope`.
            union_scope_stack.append(new_scope)
            expression_stack.extend([expression.right, expression.left])
            continue

        # NOTE: this loop rebinds the `scope` parameter; once it finishes, `scope` is the
        # last scope yielded for this operand.
        for scope in _traverse_scope(new_scope):
            yield scope

        if prev_scope:
            # Second operand done: the enclosing set operation is now complete.
            union_scope_stack.pop()
            union_scope.union_scopes = [prev_scope, scope]
            prev_scope = union_scope

            yield union_scope
        else:
            # First operand done: remember it until its sibling has been traversed.
            prev_scope = scope
656
657
def _traverse_ctes(scope):
    """
    Yield the scopes of every CTE defined in `scope`.

    The last scope yielded for each CTE is registered under the CTE's alias, so that
    later CTEs and `scope` itself can refer to it as a source.
    """
    cte_sources = {}

    for cte in scope.ctes:
        alias = cte.alias

        # A recursive CTE must have the form `base_case UNION recursive`, so the first
        # section of the union is registered up front as the source for the CTE's name.
        with_clause = scope.expression.args.get("with")
        if with_clause and with_clause.recursive:
            body = cte.this

            if isinstance(body, exp.SetOperation):
                cte_sources[alias] = scope.branch(body.this, scope_type=ScopeType.CTE)

        cte_scope = scope.branch(
            cte.this,
            cte_sources=cte_sources,
            outer_columns=cte.alias_column_names,
            scope_type=ScopeType.CTE,
        )

        last_scope = None
        for last_scope in _traverse_scope(cte_scope):
            yield last_scope

        # only the final scope yielded stands for the CTE as a whole
        if last_scope:
            cte_sources[alias] = last_scope
            scope.cte_scopes.append(last_scope)

    scope.sources.update(cte_sources)
    scope.cte_sources.update(cte_sources)
692
693
def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    Tell whether `expression` is a Subquery that introduces a new scope.

    `(tbl1 JOIN tbl2)` is also represented as a Subquery, yet it is not really a
    "derived table" because it does not introduce a new scope. An alias is the exception:
    it shadows all names under the Subquery, so an aliased Subquery always counts.
    """
    if not isinstance(expression, exp.Subquery):
        return False

    return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
703
704
def _is_from_or_join(expression: exp.Expression) -> bool:
    """
    Determine if `expression` is the FROM or JOIN clause of a SELECT statement.
    """
    ancestor = expression.parent

    # Skip over any number of wrapping Subquery nodes
    while isinstance(ancestor, exp.Subquery):
        ancestor = ancestor.parent

    return isinstance(ancestor, (exp.From, exp.Join))
716
717
def _traverse_tables(scope):
    """
    Yield the scopes of the derived tables and UDTFs that `scope` selects from, and
    register every FROM/JOIN/LATERAL source in `scope.sources`.

    The last scope yielded for each derived table or UDTF is appended to the matching
    child list (`derived_table_scopes` or `udtf_scopes`) and to `table_scopes`.
    """
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # NOTE: `expressions` is a worklist; it is extended while being iterated so that
    # joins attached to tables and nested constructs are visited as well.
    for expression in expressions:
        if isinstance(expression, exp.Final):
            expression = expression.this
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                # The name is already taken in this scope, so register under a fresh one.
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            # UDTFs receive the sources gathered so far as their lateral sources.
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_columns=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            sources[_get_source_alias(expression)] = child_scope

        # append the final child_scope yielded
        if child_scope:
            scopes.append(child_scope)
            scope.table_scopes.append(child_scope)

    scope.sources.update(sources)
803
804
def _traverse_subqueries(scope):
    """
    Yield the scopes nested under each subquery of `scope`.

    The last scope yielded for a subquery is recorded in `scope.subquery_scopes`.
    """
    for subquery in scope.subqueries:
        top = None
        for top in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield top

        # `_traverse_scope` may yield nothing (it only logs a warning for expression types
        # it cannot handle). Skip the append in that case, as `_traverse_ctes` and
        # `_traverse_tables` do, so that `subquery_scopes` never contains None.
        if top is not None:
            scope.subquery_scopes.append(top)
812
813
def _traverse_udtfs(scope):
    """
    Yield the scopes of the derived tables nested inside a UDTF scope.

    Only `exp.Unnest` (its `expressions`) and `exp.Lateral` (its `this`) are inspected.
    The last scope yielded for each derived table is recorded in `scope.subquery_scopes`
    and registered in `scope.sources` under the derived table's alias.
    """
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if not _is_derived_table(expression):
            continue

        top = None
        for top in _traverse_scope(
            scope.branch(
                expression,
                scope_type=ScopeType.SUBQUERY,
                outer_columns=expression.alias_column_names,
            )
        ):
            yield top
            # Overwritten on every yield, so the last scope ends up as the source.
            sources[_get_source_alias(expression)] = top

        # `_traverse_scope` may yield nothing (it only logs a warning for expression types
        # it cannot handle). Skip the append in that case, as `_traverse_ctes` and
        # `_traverse_tables` do, so that `subquery_scopes` never contains None.
        if top is not None:
            scope.subquery_scopes.append(top)

    scope.sources.update(sources)
840
841
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: each visited node. A node that starts a child scope is itself
        yielded, but its subtree is skipped.
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    # The lambda reads `crossed_scope_boundary` when it is called, not when it is defined,
    # so setting the flag after yielding a node affects the pruning decision for that node.
    # NOTE(review): this relies on `expression.walk` consulting `prune` only after the
    # consumer resumes the generator — confirm against `Expression.walk`.
    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        # The root never counts as a boundary of its own scope.
        if node is expression:
            continue

        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                # NOTE(review): `prune` is not forwarded to this nested walk — confirm intended.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs)
887
888
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the requested types once instead of rebuilding the tuple for every node.
    types = tuple(ensure_collection(expression_types))

    # A separate loop name keeps the `expression` parameter from being shadowed.
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node
907
908
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope that is an instance of any of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None when nothing in the scope matches.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match

    return None
925
926
927def _get_source_alias(expression):
928    alias_arg = expression.args.get("alias")
929    alias_name = expression.alias
930
931    if not alias_name and isinstance(alias_arg, exp.TableAlias) and len(alias_arg.columns) == 1:
932        alias_name = alias_arg.columns[0].name
933
934    return alias_name
logger = <Logger sqlglot (WARNING)>
TRAVERSABLES = (<class 'sqlglot.expressions.Query'>, <class 'sqlglot.expressions.DDL'>, <class 'sqlglot.expressions.DML'>)
class ScopeType(enum.Enum):
class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    # Values are spelled out explicitly; they are the same 1..6 sequence `auto()` assigns.
    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6

An enumeration.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
class Scope:
 28class Scope:
 29    """
 30    Selection scope.
 31
 32    Attributes:
 33        expression (exp.Select|exp.SetOperation): Root expression of this scope
 34        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 35            a Table expression or another Scope instance. For example:
 36                SELECT * FROM x                     {"x": Table(this="x")}
 37                SELECT * FROM x AS y                {"y": Table(this="x")}
 38                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 39        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 40            For example:
 41                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 42            The LATERAL VIEW EXPLODE gets x as a source.
 43        cte_sources (dict[str, Scope]): Sources from CTES
 44        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
 45            defines a column list for the alias of this scope, this is that list of columns.
 46            For example:
 47                SELECT * FROM (SELECT ...) AS y(col1, col2)
 48            The inner query would have `["col1", "col2"]` for its `outer_columns`
 49        parent (Scope): Parent scope
 50        scope_type (ScopeType): Type of this scope, relative to it's parent
 51        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 52        cte_scopes (list[Scope]): List of all child scopes for CTEs
 53        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 54        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 55        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 56        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 57            a list of the left and right child scopes.
 58    """
 59
 60    def __init__(
 61        self,
 62        expression,
 63        sources=None,
 64        outer_columns=None,
 65        parent=None,
 66        scope_type=ScopeType.ROOT,
 67        lateral_sources=None,
 68        cte_sources=None,
 69        can_be_correlated=None,
 70    ):
 71        self.expression = expression
 72        self.sources = sources or {}
 73        self.lateral_sources = lateral_sources or {}
 74        self.cte_sources = cte_sources or {}
 75        self.sources.update(self.lateral_sources)
 76        self.sources.update(self.cte_sources)
 77        self.outer_columns = outer_columns or []
 78        self.parent = parent
 79        self.scope_type = scope_type
 80        self.subquery_scopes = []
 81        self.derived_table_scopes = []
 82        self.table_scopes = []
 83        self.cte_scopes = []
 84        self.union_scopes = []
 85        self.udtf_scopes = []
 86        self.can_be_correlated = can_be_correlated
 87        self.clear_cache()
 88
 89    def clear_cache(self):
 90        self._collected = False
 91        self._raw_columns = None
 92        self._table_columns = None
 93        self._stars = None
 94        self._derived_tables = None
 95        self._udtfs = None
 96        self._tables = None
 97        self._ctes = None
 98        self._subqueries = None
 99        self._selected_sources = None
100        self._columns = None
101        self._external_columns = None
102        self._join_hints = None
103        self._pivots = None
104        self._references = None
105        self._semi_anti_join_tables = None
106
107    def branch(
108        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
109    ):
110        """Branch from the current scope to a new, inner scope"""
111        return Scope(
112            expression=expression.unnest(),
113            sources=sources.copy() if sources else None,
114            parent=self,
115            scope_type=scope_type,
116            cte_sources={**self.cte_sources, **(cte_sources or {})},
117            lateral_sources=lateral_sources.copy() if lateral_sources else None,
118            can_be_correlated=self.can_be_correlated
119            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
120            **kwargs,
121        )
122
123    def _collect(self):
124        self._tables = []
125        self._ctes = []
126        self._subqueries = []
127        self._derived_tables = []
128        self._udtfs = []
129        self._raw_columns = []
130        self._table_columns = []
131        self._stars = []
132        self._join_hints = []
133        self._semi_anti_join_tables = set()
134
135        for node in self.walk(bfs=False):
136            if node is self.expression:
137                continue
138
139            if isinstance(node, exp.Dot) and node.is_star:
140                self._stars.append(node)
141            elif isinstance(node, exp.Column):
142                if isinstance(node.this, exp.Star):
143                    self._stars.append(node)
144                else:
145                    self._raw_columns.append(node)
146            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
147                parent = node.parent
148                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
149                    self._semi_anti_join_tables.add(node.alias_or_name)
150
151                self._tables.append(node)
152            elif isinstance(node, exp.JoinHint):
153                self._join_hints.append(node)
154            elif isinstance(node, exp.UDTF):
155                self._udtfs.append(node)
156            elif isinstance(node, exp.CTE):
157                self._ctes.append(node)
158            elif _is_derived_table(node) and _is_from_or_join(node):
159                self._derived_tables.append(node)
160            elif isinstance(node, exp.UNWRAPPED_QUERIES):
161                self._subqueries.append(node)
162            elif isinstance(node, exp.TableColumn):
163                self._table_columns.append(node)
164
165        self._collected = True
166
167    def _ensure_collected(self):
168        if not self._collected:
169            self._collect()
170
171    def walk(self, bfs=True, prune=None):
172        return walk_in_scope(self.expression, bfs=bfs, prune=None)
173
174    def find(self, *expression_types, bfs=True):
175        return find_in_scope(self.expression, expression_types, bfs=bfs)
176
177    def find_all(self, *expression_types, bfs=True):
178        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
179
180    def replace(self, old, new):
181        """
182        Replace `old` with `new`.
183
184        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
185
186        Args:
187            old (exp.Expression): old node
188            new (exp.Expression): new node
189        """
190        old.replace(new)
191        self.clear_cache()
192
193    @property
194    def tables(self):
195        """
196        List of tables in this scope.
197
198        Returns:
199            list[exp.Table]: tables
200        """
201        self._ensure_collected()
202        return self._tables
203
204    @property
205    def ctes(self):
206        """
207        List of CTEs in this scope.
208
209        Returns:
210            list[exp.CTE]: ctes
211        """
212        self._ensure_collected()
213        return self._ctes
214
215    @property
216    def derived_tables(self):
217        """
218        List of derived tables in this scope.
219
220        For example:
221            SELECT * FROM (SELECT ...) <- that's a derived table
222
223        Returns:
224            list[exp.Subquery]: derived tables
225        """
226        self._ensure_collected()
227        return self._derived_tables
228
229    @property
230    def udtfs(self):
231        """
232        List of "User Defined Tabular Functions" in this scope.
233
234        Returns:
235            list[exp.UDTF]: UDTFs
236        """
237        self._ensure_collected()
238        return self._udtfs
239
240    @property
241    def subqueries(self):
242        """
243        List of subqueries in this scope.
244
245        For example:
246            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
247
248        Returns:
249            list[exp.Select | exp.SetOperation]: subqueries
250        """
251        self._ensure_collected()
252        return self._subqueries
253
254    @property
255    def stars(self) -> t.List[exp.Column | exp.Dot]:
256        """
257        List of star expressions (columns or dots) in this scope.
258        """
259        self._ensure_collected()
260        return self._stars
261
262    @property
263    def columns(self):
264        """
265        List of columns in this scope.
266
267        Returns:
268            list[exp.Column]: Column instances in this scope, plus any
269                Columns that reference this scope from correlated subqueries.
270        """
271        if self._columns is None:
272            self._ensure_collected()
273            columns = self._raw_columns
274
275            external_columns = [
276                column
277                for scope in itertools.chain(
278                    self.subquery_scopes,
279                    self.udtf_scopes,
280                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
281                )
282                for column in scope.external_columns
283            ]
284
285            named_selects = set(self.expression.named_selects)
286
287            self._columns = []
288            for column in columns + external_columns:
289                ancestor = column.find_ancestor(
290                    exp.Select,
291                    exp.Qualify,
292                    exp.Order,
293                    exp.Having,
294                    exp.Hint,
295                    exp.Table,
296                    exp.Star,
297                    exp.Distinct,
298                )
299                if (
300                    not ancestor
301                    or column.table
302                    or isinstance(ancestor, exp.Select)
303                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
304                    or (
305                        isinstance(ancestor, (exp.Order, exp.Distinct))
306                        and (
307                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
308                            or not isinstance(ancestor.parent, exp.Select)
309                            or column.name not in named_selects
310                        )
311                    )
312                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
313                ):
314                    self._columns.append(column)
315
316        return self._columns
317
318    @property
319    def table_columns(self):
320        if self._table_columns is None:
321            self._ensure_collected()
322
323        return self._table_columns
324
325    @property
326    def selected_sources(self):
327        """
328        Mapping of nodes and sources that are actually selected from in this scope.
329
330        That is, all tables in a schema are selectable at any point. But a
331        table only becomes a selected source if it's included in a FROM or JOIN clause.
332
333        Returns:
334            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
335        """
336        if self._selected_sources is None:
337            result = {}
338
339            for name, node in self.references:
340                if name in self._semi_anti_join_tables:
341                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
342                    # selected source
343                    continue
344
345                if name in result:
346                    raise OptimizeError(f"Alias already used: {name}")
347                if name in self.sources:
348                    result[name] = (node, self.sources[name])
349
350            self._selected_sources = result
351        return self._selected_sources
352
353    @property
354    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
355        if self._references is None:
356            self._references = []
357
358            for table in self.tables:
359                self._references.append((table.alias_or_name, table))
360            for expression in itertools.chain(self.derived_tables, self.udtfs):
361                self._references.append(
362                    (
363                        _get_source_alias(expression),
364                        expression if expression.args.get("pivots") else expression.unnest(),
365                    )
366                )
367
368        return self._references
369
370    @property
371    def external_columns(self):
372        """
373        Columns that appear to reference sources in outer scopes.
374
375        Returns:
376            list[exp.Column]: Column instances that don't reference
377                sources in the current scope.
378        """
379        if self._external_columns is None:
380            if isinstance(self.expression, exp.SetOperation):
381                left, right = self.union_scopes
382                self._external_columns = left.external_columns + right.external_columns
383            else:
384                self._external_columns = [
385                    c
386                    for c in self.columns
387                    if c.table not in self.selected_sources
388                    and c.table not in self.semi_or_anti_join_tables
389                ]
390
391        return self._external_columns
392
393    @property
394    def unqualified_columns(self):
395        """
396        Unqualified columns in the current scope.
397
398        Returns:
399             list[exp.Column]: Unqualified columns
400        """
401        return [c for c in self.columns if not c.table]
402
403    @property
404    def join_hints(self):
405        """
406        Hints that exist in the scope that reference tables
407
408        Returns:
409            list[exp.JoinHint]: Join hints that are referenced within the scope
410        """
411        if self._join_hints is None:
412            return []
413        return self._join_hints
414
415    @property
416    def pivots(self):
417        if not self._pivots:
418            self._pivots = [
419                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
420            ]
421
422        return self._pivots
423
424    @property
425    def semi_or_anti_join_tables(self):
426        return self._semi_anti_join_tables or set()
427
428    def source_columns(self, source_name):
429        """
430        Get all columns in the current scope for a particular source.
431
432        Args:
433            source_name (str): Name of the source
434        Returns:
435            list[exp.Column]: Column instances that reference `source_name`
436        """
437        return [column for column in self.columns if column.table == source_name]
438
439    @property
440    def is_subquery(self):
441        """Determine if this scope is a subquery"""
442        return self.scope_type == ScopeType.SUBQUERY
443
444    @property
445    def is_derived_table(self):
446        """Determine if this scope is a derived table"""
447        return self.scope_type == ScopeType.DERIVED_TABLE
448
449    @property
450    def is_union(self):
451        """Determine if this scope is a union"""
452        return self.scope_type == ScopeType.UNION
453
454    @property
455    def is_cte(self):
456        """Determine if this scope is a common table expression"""
457        return self.scope_type == ScopeType.CTE
458
459    @property
460    def is_root(self):
461        """Determine if this is the root scope"""
462        return self.scope_type == ScopeType.ROOT
463
464    @property
465    def is_udtf(self):
466        """Determine if this scope is a UDTF (User Defined Table Function)"""
467        return self.scope_type == ScopeType.UDTF
468
469    @property
470    def is_correlated_subquery(self):
471        """Determine if this scope is a correlated subquery"""
472        return bool(self.can_be_correlated and self.external_columns)
473
474    def rename_source(self, old_name, new_name):
475        """Rename a source in this scope"""
476        columns = self.sources.pop(old_name or "", [])
477        self.sources[new_name] = columns
478
479    def add_source(self, name, source):
480        """Add a source to this scope"""
481        self.sources[name] = source
482        self.clear_cache()
483
484    def remove_source(self, name):
485        """Remove a source from this scope"""
486        self.sources.pop(name, None)
487        self.clear_cache()
488
489    def __repr__(self):
490        return f"Scope<{self.expression.sql()}>"
491
492    def traverse(self):
493        """
494        Traverse the scope tree from this node.
495
496        Yields:
497            Scope: scope instances in depth-first-search post-order
498        """
499        stack = [self]
500        result = []
501        while stack:
502            scope = stack.pop()
503            result.append(scope)
504            stack.extend(
505                itertools.chain(
506                    scope.cte_scopes,
507                    scope.union_scopes,
508                    scope.table_scopes,
509                    scope.subquery_scopes,
510                )
511            )
512
513        yield from reversed(result)
514
515    def ref_count(self):
516        """
517        Count the number of times each scope in this tree is referenced.
518
519        Returns:
520            dict[int, int]: Mapping of Scope instance ID to reference count
521        """
522        scope_ref_count = defaultdict(lambda: 0)
523
524        for scope in self.traverse():
525            for _, source in scope.selected_sources.values():
526                scope_ref_count[id(source)] += 1
527
528            for name in scope._semi_anti_join_tables:
529                # semi/anti join sources are not actually selected but we still need to
530                # increment their ref count to avoid them being optimized away
531                if name in scope.sources:
532                    scope_ref_count[id(scope.sources[name])] += 1
533
534        return scope_ref_count

Selection scope.

Attributes:
  • expression (exp.Select|exp.SetOperation): Root expression of this scope
  • sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
  • cte_sources (dict[str, Scope]): Sources from CTES
  • outer_columns (list[str]): If this is a derived table or CTE, and the outer query defines a column list for the alias of this scope, this is that list of columns. For example: SELECT * FROM (SELECT ...) AS y(col1, col2) The inner query would have ["col1", "col2"] for its outer_columns
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to its parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
  • udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
Scope( expression, sources=None, outer_columns=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None, cte_sources=None, can_be_correlated=None)
60    def __init__(
61        self,
62        expression,
63        sources=None,
64        outer_columns=None,
65        parent=None,
66        scope_type=ScopeType.ROOT,
67        lateral_sources=None,
68        cte_sources=None,
69        can_be_correlated=None,
70    ):
71        self.expression = expression
72        self.sources = sources or {}
73        self.lateral_sources = lateral_sources or {}
74        self.cte_sources = cte_sources or {}
75        self.sources.update(self.lateral_sources)
76        self.sources.update(self.cte_sources)
77        self.outer_columns = outer_columns or []
78        self.parent = parent
79        self.scope_type = scope_type
80        self.subquery_scopes = []
81        self.derived_table_scopes = []
82        self.table_scopes = []
83        self.cte_scopes = []
84        self.union_scopes = []
85        self.udtf_scopes = []
86        self.can_be_correlated = can_be_correlated
87        self.clear_cache()
expression
sources
lateral_sources
cte_sources
outer_columns
parent
scope_type
subquery_scopes
derived_table_scopes
table_scopes
cte_scopes
union_scopes
udtf_scopes
can_be_correlated
def clear_cache(self):
 89    def clear_cache(self):
 90        self._collected = False
 91        self._raw_columns = None
 92        self._table_columns = None
 93        self._stars = None
 94        self._derived_tables = None
 95        self._udtfs = None
 96        self._tables = None
 97        self._ctes = None
 98        self._subqueries = None
 99        self._selected_sources = None
100        self._columns = None
101        self._external_columns = None
102        self._join_hints = None
103        self._pivots = None
104        self._references = None
105        self._semi_anti_join_tables = None
def branch( self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs):
107    def branch(
108        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
109    ):
110        """Branch from the current scope to a new, inner scope"""
111        return Scope(
112            expression=expression.unnest(),
113            sources=sources.copy() if sources else None,
114            parent=self,
115            scope_type=scope_type,
116            cte_sources={**self.cte_sources, **(cte_sources or {})},
117            lateral_sources=lateral_sources.copy() if lateral_sources else None,
118            can_be_correlated=self.can_be_correlated
119            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
120            **kwargs,
121        )

Branch from the current scope to a new, inner scope

def walk(self, bfs=True, prune=None):
171    def walk(self, bfs=True, prune=None):
172        return walk_in_scope(self.expression, bfs=bfs, prune=None)
def find(self, *expression_types, bfs=True):
174    def find(self, *expression_types, bfs=True):
175        return find_in_scope(self.expression, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
177    def find_all(self, *expression_types, bfs=True):
178        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
def replace(self, old, new):
180    def replace(self, old, new):
181        """
182        Replace `old` with `new`.
183
184        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
185
186        Args:
187            old (exp.Expression): old node
188            new (exp.Expression): new node
189        """
190        old.replace(new)
191        self.clear_cache()

Replace old with new.

This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expression): old node
  • new (exp.Expression): new node
tables
193    @property
194    def tables(self):
195        """
196        List of tables in this scope.
197
198        Returns:
199            list[exp.Table]: tables
200        """
201        self._ensure_collected()
202        return self._tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes
204    @property
205    def ctes(self):
206        """
207        List of CTEs in this scope.
208
209        Returns:
210            list[exp.CTE]: ctes
211        """
212        self._ensure_collected()
213        return self._ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables
215    @property
216    def derived_tables(self):
217        """
218        List of derived tables in this scope.
219
220        For example:
221            SELECT * FROM (SELECT ...) <- that's a derived table
222
223        Returns:
224            list[exp.Subquery]: derived tables
225        """
226        self._ensure_collected()
227        return self._derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs
229    @property
230    def udtfs(self):
231        """
232        List of "User Defined Tabular Functions" in this scope.
233
234        Returns:
235            list[exp.UDTF]: UDTFs
236        """
237        self._ensure_collected()
238        return self._udtfs

List of "User Defined Tabular Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

subqueries
240    @property
241    def subqueries(self):
242        """
243        List of subqueries in this scope.
244
245        For example:
246            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
247
248        Returns:
249            list[exp.Select | exp.SetOperation]: subqueries
250        """
251        self._ensure_collected()
252        return self._subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Select | exp.SetOperation]: subqueries

254    @property
255    def stars(self) -> t.List[exp.Column | exp.Dot]:
256        """
257        List of star expressions (columns or dots) in this scope.
258        """
259        self._ensure_collected()
260        return self._stars

List of star expressions (columns or dots) in this scope.

columns
262    @property
263    def columns(self):
264        """
265        List of columns in this scope.
266
267        Returns:
268            list[exp.Column]: Column instances in this scope, plus any
269                Columns that reference this scope from correlated subqueries.
270        """
271        if self._columns is None:
272            self._ensure_collected()
273            columns = self._raw_columns
274
275            external_columns = [
276                column
277                for scope in itertools.chain(
278                    self.subquery_scopes,
279                    self.udtf_scopes,
280                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
281                )
282                for column in scope.external_columns
283            ]
284
285            named_selects = set(self.expression.named_selects)
286
287            self._columns = []
288            for column in columns + external_columns:
289                ancestor = column.find_ancestor(
290                    exp.Select,
291                    exp.Qualify,
292                    exp.Order,
293                    exp.Having,
294                    exp.Hint,
295                    exp.Table,
296                    exp.Star,
297                    exp.Distinct,
298                )
299                if (
300                    not ancestor
301                    or column.table
302                    or isinstance(ancestor, exp.Select)
303                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
304                    or (
305                        isinstance(ancestor, (exp.Order, exp.Distinct))
306                        and (
307                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
308                            or not isinstance(ancestor.parent, exp.Select)
309                            or column.name not in named_selects
310                        )
311                    )
312                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
313                ):
314                    self._columns.append(column)
315
316        return self._columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

table_columns
318    @property
319    def table_columns(self):
320        if self._table_columns is None:
321            self._ensure_collected()
322
323        return self._table_columns
selected_sources
325    @property
326    def selected_sources(self):
327        """
328        Mapping of nodes and sources that are actually selected from in this scope.
329
330        That is, all tables in a schema are selectable at any point. But a
331        table only becomes a selected source if it's included in a FROM or JOIN clause.
332
333        Returns:
334            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
335        """
336        if self._selected_sources is None:
337            result = {}
338
339            for name, node in self.references:
340                if name in self._semi_anti_join_tables:
341                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
342                    # selected source
343                    continue
344
345                if name in result:
346                    raise OptimizeError(f"Alias already used: {name}")
347                if name in self.sources:
348                    result[name] = (node, self.sources[name])
349
350            self._selected_sources = result
351        return self._selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

references: List[Tuple[str, sqlglot.expressions.Expression]]
353    @property
354    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
355        if self._references is None:
356            self._references = []
357
358            for table in self.tables:
359                self._references.append((table.alias_or_name, table))
360            for expression in itertools.chain(self.derived_tables, self.udtfs):
361                self._references.append(
362                    (
363                        _get_source_alias(expression),
364                        expression if expression.args.get("pivots") else expression.unnest(),
365                    )
366                )
367
368        return self._references
external_columns
370    @property
371    def external_columns(self):
372        """
373        Columns that appear to reference sources in outer scopes.
374
375        Returns:
376            list[exp.Column]: Column instances that don't reference
377                sources in the current scope.
378        """
379        if self._external_columns is None:
380            if isinstance(self.expression, exp.SetOperation):
381                left, right = self.union_scopes
382                self._external_columns = left.external_columns + right.external_columns
383            else:
384                self._external_columns = [
385                    c
386                    for c in self.columns
387                    if c.table not in self.selected_sources
388                    and c.table not in self.semi_or_anti_join_tables
389                ]
390
391        return self._external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

unqualified_columns
393    @property
394    def unqualified_columns(self):
395        """
396        Unqualified columns in the current scope.
397
398        Returns:
399             list[exp.Column]: Unqualified columns
400        """
401        return [c for c in self.columns if not c.table]

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints
403    @property
404    def join_hints(self):
405        """
406        Hints that exist in the scope that reference tables
407
408        Returns:
409            list[exp.JoinHint]: Join hints that are referenced within the scope
410        """
411        if self._join_hints is None:
412            return []
413        return self._join_hints

Hints that exist in the scope that reference tables

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

pivots
415    @property
416    def pivots(self):
417        if not self._pivots:
418            self._pivots = [
419                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
420            ]
421
422        return self._pivots
semi_or_anti_join_tables
424    @property
425    def semi_or_anti_join_tables(self):
426        return self._semi_anti_join_tables or set()
def source_columns(self, source_name):
428    def source_columns(self, source_name):
429        """
430        Get all columns in the current scope for a particular source.
431
432        Args:
433            source_name (str): Name of the source
434        Returns:
435            list[exp.Column]: Column instances that reference `source_name`
436        """
437        return [column for column in self.columns if column.table == source_name]

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery
439    @property
440    def is_subquery(self):
441        """Determine if this scope is a subquery"""
442        return self.scope_type == ScopeType.SUBQUERY

Determine if this scope is a subquery

is_derived_table
444    @property
445    def is_derived_table(self):
446        """Determine if this scope is a derived table"""
447        return self.scope_type == ScopeType.DERIVED_TABLE

Determine if this scope is a derived table

is_union
449    @property
450    def is_union(self):
451        """Determine if this scope is a union"""
452        return self.scope_type == ScopeType.UNION

Determine if this scope is a union

is_cte
454    @property
455    def is_cte(self):
456        """Determine if this scope is a common table expression"""
457        return self.scope_type == ScopeType.CTE

Determine if this scope is a common table expression

is_root
459    @property
460    def is_root(self):
461        """Determine if this is the root scope"""
462        return self.scope_type == ScopeType.ROOT

Determine if this is the root scope

is_udtf
464    @property
465    def is_udtf(self):
466        """Determine if this scope is a UDTF (User Defined Table Function)"""
467        return self.scope_type == ScopeType.UDTF

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery
469    @property
470    def is_correlated_subquery(self):
471        """Determine if this scope is a correlated subquery"""
472        return bool(self.can_be_correlated and self.external_columns)

Determine if this scope is a correlated subquery

def rename_source(self, old_name, new_name):
474    def rename_source(self, old_name, new_name):
475        """Rename a source in this scope"""
476        columns = self.sources.pop(old_name or "", [])
477        self.sources[new_name] = columns

Rename a source in this scope

def add_source(self, name, source):
479    def add_source(self, name, source):
480        """Add a source to this scope"""
481        self.sources[name] = source
482        self.clear_cache()

Add a source to this scope

def remove_source(self, name):
484    def remove_source(self, name):
485        """Remove a source from this scope"""
486        self.sources.pop(name, None)
487        self.clear_cache()

Remove a source from this scope

def traverse(self):
492    def traverse(self):
493        """
494        Traverse the scope tree from this node.
495
496        Yields:
497            Scope: scope instances in depth-first-search post-order
498        """
499        stack = [self]
500        result = []
501        while stack:
502            scope = stack.pop()
503            result.append(scope)
504            stack.extend(
505                itertools.chain(
506                    scope.cte_scopes,
507                    scope.union_scopes,
508                    scope.table_scopes,
509                    scope.subquery_scopes,
510                )
511            )
512
513        yield from reversed(result)

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self):
515    def ref_count(self):
516        """
517        Count the number of times each scope in this tree is referenced.
518
519        Returns:
520            dict[int, int]: Mapping of Scope instance ID to reference count
521        """
522        scope_ref_count = defaultdict(lambda: 0)
523
524        for scope in self.traverse():
525            for _, source in scope.selected_sources.values():
526                scope_ref_count[id(source)] += 1
527
528            for name in scope._semi_anti_join_tables:
529                # semi/anti join sources are not actually selected but we still need to
530                # increment their ref count to avoid them being optimized away
531                if name in scope.sources:
532                    scope_ref_count[id(scope.sources[name])] += 1
533
534        return scope_ref_count

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope( expression: sqlglot.expressions.Expression) -> List[Scope]:
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Collect the scopes of an expression.

    A "Scope" captures the context of a Select statement, which gives
    optimizations more information than the bare expression tree does, e.g.
    the source names visible inside a subquery. The scopes are returned as a
    fully built list rather than lazily, because a generator could expose
    scopes whose properties are still incomplete.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty when `expression` is not
        an instance of one of the `TRAVERSABLES` types
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression: Expression to traverse
Returns:

A list of the created scope instances

def build_scope( expression: sqlglot.expressions.Expression) -> Optional[Scope]:
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build the scope tree for an expression and hand back its root.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, taken as the last scope produced by `traverse_scope`
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)

Build a scope tree.

Arguments:
  • expression: Expression to build the scope tree for.
Returns:

The root scope

def walk_in_scope(expression, bfs=True, prune=None):
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Nodes that start a child scope are themselves yielded, but their subtrees are
    skipped. For such `exp.Subquery` / `exp.UDTF` nodes, the "joins", "laterals"
    and "pivots" args are still walked, since they belong to the current scope.

    Args:
        expression (exp.Expression): root of the walk; it is never treated as a
            scope boundary itself.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.
            It is not forwarded to the nested walks over "joins", "laterals"
            and "pivots".

    Yields:
        exp.Expression: each visited node
    """
    # We'll use this variable to pass state into the prune callback given to `walk`.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        # Reset the flag so that only the node flagged below gets pruned
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue

        # Any of these conditions means `node` starts a child scope
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs)

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.

Arguments:
  • expression (exp.Expression):
  • bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
  • prune ((node) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:

exp.Expression: each visited node

def find_all_in_scope(expression, expression_types, bfs=True):
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the requested types once instead of once per visited node.
    types = tuple(ensure_collection(expression_types))

    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:

exp.Expression: nodes

def find_in_scope(expression, expression_types, bfs=True):
def find_in_scope(expression, expression_types, bfs=True):
    """
    Look up the first node in this scope that matches one of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None if no node matching
        the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:

exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.