Edit on GitHub

sqlglot.optimizer.scope

  1from __future__ import annotations
  2
  3import itertools
  4import logging
  5import typing as t
  6from collections import defaultdict
  7from enum import Enum, auto
  8
  9from sqlglot import exp
 10from sqlglot.errors import OptimizeError
 11from sqlglot.helper import ensure_collection, find_new_name, seq_get
 12
# Logger registered under the name "sqlglot"; `_traverse_scope` uses it to warn about
# expressions it cannot traverse.
logger = logging.getLogger("sqlglot")

# Expression types accepted as a root by `traverse_scope`; other inputs produce no scopes.
TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)
 16
 17
 18class ScopeType(Enum):
 19    ROOT = auto()
 20    SUBQUERY = auto()
 21    DERIVED_TABLE = auto()
 22    CTE = auto()
 23    UNION = auto()
 24    UDTF = auto()
 25
 26
 27class Scope:
 28    """
 29    Selection scope.
 30
 31    Attributes:
 32        expression (exp.Select|exp.SetOperation): Root expression of this scope
 33        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 34            a Table expression or another Scope instance. For example:
 35                SELECT * FROM x                     {"x": Table(this="x")}
 36                SELECT * FROM x AS y                {"y": Table(this="x")}
 37                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 38        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 39            For example:
 40                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 41            The LATERAL VIEW EXPLODE gets x as a source.
 42        cte_sources (dict[str, Scope]): Sources from CTES
 43        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
 44            defines a column list for the alias of this scope, this is that list of columns.
 45            For example:
 46                SELECT * FROM (SELECT ...) AS y(col1, col2)
 47            The inner query would have `["col1", "col2"]` for its `outer_columns`
 48        parent (Scope): Parent scope
 49        scope_type (ScopeType): Type of this scope, relative to it's parent
 50        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 51        cte_scopes (list[Scope]): List of all child scopes for CTEs
 52        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 53        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 54        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 55        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 56            a list of the left and right child scopes.
 57    """
 58
 59    def __init__(
 60        self,
 61        expression,
 62        sources=None,
 63        outer_columns=None,
 64        parent=None,
 65        scope_type=ScopeType.ROOT,
 66        lateral_sources=None,
 67        cte_sources=None,
 68        can_be_correlated=None,
 69    ):
 70        self.expression = expression
 71        self.sources = sources or {}
 72        self.lateral_sources = lateral_sources or {}
 73        self.cte_sources = cte_sources or {}
 74        self.sources.update(self.lateral_sources)
 75        self.sources.update(self.cte_sources)
 76        self.outer_columns = outer_columns or []
 77        self.parent = parent
 78        self.scope_type = scope_type
 79        self.subquery_scopes = []
 80        self.derived_table_scopes = []
 81        self.table_scopes = []
 82        self.cte_scopes = []
 83        self.union_scopes = []
 84        self.udtf_scopes = []
 85        self.can_be_correlated = can_be_correlated
 86        self.clear_cache()
 87
 88    def clear_cache(self):
 89        self._collected = False
 90        self._raw_columns = None
 91        self._stars = None
 92        self._derived_tables = None
 93        self._udtfs = None
 94        self._tables = None
 95        self._ctes = None
 96        self._subqueries = None
 97        self._selected_sources = None
 98        self._columns = None
 99        self._external_columns = None
100        self._join_hints = None
101        self._pivots = None
102        self._references = None
103        self._semi_anti_join_tables = None
104
105    def branch(
106        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
107    ):
108        """Branch from the current scope to a new, inner scope"""
109        return Scope(
110            expression=expression.unnest(),
111            sources=sources.copy() if sources else None,
112            parent=self,
113            scope_type=scope_type,
114            cte_sources={**self.cte_sources, **(cte_sources or {})},
115            lateral_sources=lateral_sources.copy() if lateral_sources else None,
116            can_be_correlated=self.can_be_correlated
117            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
118            **kwargs,
119        )
120
121    def _collect(self):
122        self._tables = []
123        self._ctes = []
124        self._subqueries = []
125        self._derived_tables = []
126        self._udtfs = []
127        self._raw_columns = []
128        self._stars = []
129        self._join_hints = []
130        self._semi_anti_join_tables = set()
131
132        for node in self.walk(bfs=False):
133            if node is self.expression:
134                continue
135
136            if isinstance(node, exp.Dot) and node.is_star:
137                self._stars.append(node)
138            elif isinstance(node, exp.Column):
139                if isinstance(node.this, exp.Star):
140                    self._stars.append(node)
141                else:
142                    self._raw_columns.append(node)
143            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
144                parent = node.parent
145                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
146                    self._semi_anti_join_tables.add(node.alias_or_name)
147
148                self._tables.append(node)
149            elif isinstance(node, exp.JoinHint):
150                self._join_hints.append(node)
151            elif isinstance(node, exp.UDTF):
152                self._udtfs.append(node)
153            elif isinstance(node, exp.CTE):
154                self._ctes.append(node)
155            elif _is_derived_table(node) and _is_from_or_join(node):
156                self._derived_tables.append(node)
157            elif isinstance(node, exp.UNWRAPPED_QUERIES):
158                self._subqueries.append(node)
159
160        self._collected = True
161
162    def _ensure_collected(self):
163        if not self._collected:
164            self._collect()
165
166    def walk(self, bfs=True, prune=None):
167        return walk_in_scope(self.expression, bfs=bfs, prune=None)
168
169    def find(self, *expression_types, bfs=True):
170        return find_in_scope(self.expression, expression_types, bfs=bfs)
171
172    def find_all(self, *expression_types, bfs=True):
173        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
174
175    def replace(self, old, new):
176        """
177        Replace `old` with `new`.
178
179        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
180
181        Args:
182            old (exp.Expression): old node
183            new (exp.Expression): new node
184        """
185        old.replace(new)
186        self.clear_cache()
187
188    @property
189    def tables(self):
190        """
191        List of tables in this scope.
192
193        Returns:
194            list[exp.Table]: tables
195        """
196        self._ensure_collected()
197        return self._tables
198
199    @property
200    def ctes(self):
201        """
202        List of CTEs in this scope.
203
204        Returns:
205            list[exp.CTE]: ctes
206        """
207        self._ensure_collected()
208        return self._ctes
209
210    @property
211    def derived_tables(self):
212        """
213        List of derived tables in this scope.
214
215        For example:
216            SELECT * FROM (SELECT ...) <- that's a derived table
217
218        Returns:
219            list[exp.Subquery]: derived tables
220        """
221        self._ensure_collected()
222        return self._derived_tables
223
224    @property
225    def udtfs(self):
226        """
227        List of "User Defined Tabular Functions" in this scope.
228
229        Returns:
230            list[exp.UDTF]: UDTFs
231        """
232        self._ensure_collected()
233        return self._udtfs
234
235    @property
236    def subqueries(self):
237        """
238        List of subqueries in this scope.
239
240        For example:
241            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
242
243        Returns:
244            list[exp.Select | exp.SetOperation]: subqueries
245        """
246        self._ensure_collected()
247        return self._subqueries
248
249    @property
250    def stars(self) -> t.List[exp.Column | exp.Dot]:
251        """
252        List of star expressions (columns or dots) in this scope.
253        """
254        self._ensure_collected()
255        return self._stars
256
257    @property
258    def columns(self):
259        """
260        List of columns in this scope.
261
262        Returns:
263            list[exp.Column]: Column instances in this scope, plus any
264                Columns that reference this scope from correlated subqueries.
265        """
266        if self._columns is None:
267            self._ensure_collected()
268            columns = self._raw_columns
269
270            external_columns = [
271                column
272                for scope in itertools.chain(
273                    self.subquery_scopes,
274                    self.udtf_scopes,
275                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
276                )
277                for column in scope.external_columns
278            ]
279
280            named_selects = set(self.expression.named_selects)
281
282            self._columns = []
283            for column in columns + external_columns:
284                ancestor = column.find_ancestor(
285                    exp.Select,
286                    exp.Qualify,
287                    exp.Order,
288                    exp.Having,
289                    exp.Hint,
290                    exp.Table,
291                    exp.Star,
292                    exp.Distinct,
293                )
294                if (
295                    not ancestor
296                    or column.table
297                    or isinstance(ancestor, exp.Select)
298                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
299                    or (
300                        isinstance(ancestor, (exp.Order, exp.Distinct))
301                        and (
302                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
303                            or column.name not in named_selects
304                        )
305                    )
306                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
307                ):
308                    self._columns.append(column)
309
310        return self._columns
311
312    @property
313    def selected_sources(self):
314        """
315        Mapping of nodes and sources that are actually selected from in this scope.
316
317        That is, all tables in a schema are selectable at any point. But a
318        table only becomes a selected source if it's included in a FROM or JOIN clause.
319
320        Returns:
321            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
322        """
323        if self._selected_sources is None:
324            result = {}
325
326            for name, node in self.references:
327                if name in self._semi_anti_join_tables:
328                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
329                    # selected source
330                    continue
331
332                if name in result:
333                    raise OptimizeError(f"Alias already used: {name}")
334                if name in self.sources:
335                    result[name] = (node, self.sources[name])
336
337            self._selected_sources = result
338        return self._selected_sources
339
340    @property
341    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
342        if self._references is None:
343            self._references = []
344
345            for table in self.tables:
346                self._references.append((table.alias_or_name, table))
347            for expression in itertools.chain(self.derived_tables, self.udtfs):
348                self._references.append(
349                    (
350                        expression.alias,
351                        expression if expression.args.get("pivots") else expression.unnest(),
352                    )
353                )
354
355        return self._references
356
357    @property
358    def external_columns(self):
359        """
360        Columns that appear to reference sources in outer scopes.
361
362        Returns:
363            list[exp.Column]: Column instances that don't reference
364                sources in the current scope.
365        """
366        if self._external_columns is None:
367            if isinstance(self.expression, exp.SetOperation):
368                left, right = self.union_scopes
369                self._external_columns = left.external_columns + right.external_columns
370            else:
371                self._external_columns = [
372                    c
373                    for c in self.columns
374                    if c.table not in self.selected_sources
375                    and c.table not in self.semi_or_anti_join_tables
376                ]
377
378        return self._external_columns
379
380    @property
381    def unqualified_columns(self):
382        """
383        Unqualified columns in the current scope.
384
385        Returns:
386             list[exp.Column]: Unqualified columns
387        """
388        return [c for c in self.columns if not c.table]
389
390    @property
391    def join_hints(self):
392        """
393        Hints that exist in the scope that reference tables
394
395        Returns:
396            list[exp.JoinHint]: Join hints that are referenced within the scope
397        """
398        if self._join_hints is None:
399            return []
400        return self._join_hints
401
402    @property
403    def pivots(self):
404        if not self._pivots:
405            self._pivots = [
406                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
407            ]
408
409        return self._pivots
410
411    @property
412    def semi_or_anti_join_tables(self):
413        return self._semi_anti_join_tables or set()
414
415    def source_columns(self, source_name):
416        """
417        Get all columns in the current scope for a particular source.
418
419        Args:
420            source_name (str): Name of the source
421        Returns:
422            list[exp.Column]: Column instances that reference `source_name`
423        """
424        return [column for column in self.columns if column.table == source_name]
425
426    @property
427    def is_subquery(self):
428        """Determine if this scope is a subquery"""
429        return self.scope_type == ScopeType.SUBQUERY
430
431    @property
432    def is_derived_table(self):
433        """Determine if this scope is a derived table"""
434        return self.scope_type == ScopeType.DERIVED_TABLE
435
436    @property
437    def is_union(self):
438        """Determine if this scope is a union"""
439        return self.scope_type == ScopeType.UNION
440
441    @property
442    def is_cte(self):
443        """Determine if this scope is a common table expression"""
444        return self.scope_type == ScopeType.CTE
445
446    @property
447    def is_root(self):
448        """Determine if this is the root scope"""
449        return self.scope_type == ScopeType.ROOT
450
451    @property
452    def is_udtf(self):
453        """Determine if this scope is a UDTF (User Defined Table Function)"""
454        return self.scope_type == ScopeType.UDTF
455
456    @property
457    def is_correlated_subquery(self):
458        """Determine if this scope is a correlated subquery"""
459        return bool(self.can_be_correlated and self.external_columns)
460
461    def rename_source(self, old_name, new_name):
462        """Rename a source in this scope"""
463        columns = self.sources.pop(old_name or "", [])
464        self.sources[new_name] = columns
465
466    def add_source(self, name, source):
467        """Add a source to this scope"""
468        self.sources[name] = source
469        self.clear_cache()
470
471    def remove_source(self, name):
472        """Remove a source from this scope"""
473        self.sources.pop(name, None)
474        self.clear_cache()
475
476    def __repr__(self):
477        return f"Scope<{self.expression.sql()}>"
478
479    def traverse(self):
480        """
481        Traverse the scope tree from this node.
482
483        Yields:
484            Scope: scope instances in depth-first-search post-order
485        """
486        stack = [self]
487        result = []
488        while stack:
489            scope = stack.pop()
490            result.append(scope)
491            stack.extend(
492                itertools.chain(
493                    scope.cte_scopes,
494                    scope.union_scopes,
495                    scope.table_scopes,
496                    scope.subquery_scopes,
497                )
498            )
499
500        yield from reversed(result)
501
502    def ref_count(self):
503        """
504        Count the number of times each scope in this tree is referenced.
505
506        Returns:
507            dict[int, int]: Mapping of Scope instance ID to reference count
508        """
509        scope_ref_count = defaultdict(lambda: 0)
510
511        for scope in self.traverse():
512            for _, source in scope.selected_sources.values():
513                scope_ref_count[id(source)] += 1
514
515        return scope_ref_count
516
517
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Build every scope of `expression` and return them as a list.

    A "scope" is the context of a single Select statement. The optimizer needs it when
    the expression tree alone is not enough, e.g. to know which source names are visible
    inside a subquery. The result is materialized as a list because scopes handed out by
    a half-consumed generator would have incomplete properties.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty when `expression` is not one of
        the `TRAVERSABLES` types.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))
547
548
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, i.e. the last scope produced by `traverse_scope`, or whatever
        `seq_get` returns for an empty list when no scope was produced.
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)
560
561
def _traverse_scope(scope):
    """
    Yield the scopes nested under `scope`, dispatching on the type of its expression.

    For selects, subqueries, tables and UDTFs the child scopes are yielded first and
    `scope` itself last. Set operations, DDL and DML statements delegate to other
    traversals and never yield `scope` here. Any other expression type is logged
    and yields nothing.
    """
    node = scope.expression

    # The generators below are lazy: nothing runs until they are consumed at the bottom.
    if isinstance(node, exp.Select):
        children = _traverse_select(scope)
    elif isinstance(node, exp.SetOperation):
        yield from _traverse_ctes(scope)
        yield from _traverse_union(scope)
        return
    elif isinstance(node, exp.Subquery):
        children = _traverse_select(scope) if scope.is_root else _traverse_subqueries(scope)
    elif isinstance(node, exp.Table):
        children = _traverse_tables(scope)
    elif isinstance(node, exp.UDTF):
        children = _traverse_udtfs(scope)
    elif isinstance(node, exp.DDL):
        if isinstance(node.expression, exp.Query):
            yield from _traverse_ctes(scope)
            # The query of the statement starts a fresh root scope that still sees the CTEs.
            yield from _traverse_scope(Scope(node.expression, cte_sources=scope.cte_sources))
        return
    elif isinstance(node, exp.DML):
        yield from _traverse_ctes(scope)
        for query in find_all_in_scope(node, exp.Query):
            # Queries owned by a CTE or a Subquery are skipped so that
            # CTEs and nested queries are not yielded twice.
            if isinstance(query.parent, (exp.CTE, exp.Subquery)):
                continue
            yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
        return
    else:
        logger.warning("Cannot traverse scope %s with type '%s'", node, type(node))
        return

    yield from children
    yield scope
597
598
def _traverse_select(scope):
    """Yield the child scopes of a SELECT: CTEs first, then tables, then subqueries."""
    for traverse_step in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traverse_step(scope)
603
604
def _traverse_union(scope):
    """
    Yield the scopes of a set operation without recursing on nested set operations.

    Operands are processed left to right using explicit stacks. Whenever a right operand
    has been traversed, the enclosing union scope receives its `union_scopes` pair and
    is yielded.
    """
    completed = None
    # Last scope produced by `_traverse_scope`; starts out as the scope that was passed in.
    last_yielded = scope
    open_unions = [scope]
    pending = [scope.expression.right, scope.expression.left]

    while pending:
        operand = pending.pop()
        enclosing = open_unions[-1]

        operand_scope = enclosing.branch(
            operand,
            outer_columns=enclosing.outer_columns,
            scope_type=ScopeType.UNION,
        )

        if isinstance(operand, exp.SetOperation):
            # Nested set operation: handle its CTEs now and queue both of its operands.
            yield from _traverse_ctes(operand_scope)

            open_unions.append(operand_scope)
            pending.extend([operand.right, operand.left])
            continue

        for last_yielded in _traverse_scope(operand_scope):
            yield last_yielded

        if not completed:
            # Left operand done; its union is finished once the right operand is traversed.
            completed = last_yielded
            continue

        open_unions.pop()
        enclosing.union_scopes = [completed, last_yielded]
        completed = enclosing

        yield enclosing
638
639
def _traverse_ctes(scope):
    """
    Yield the scopes of every CTE in `scope` and register each CTE as a source.

    The last scope yielded for a CTE becomes its source and is added to
    `scope.cte_scopes`. Each CTE can see the CTEs defined before it.
    """
    cte_sources = {}

    for cte in scope.ctes:
        name = cte.alias

        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
        # thus the recursive scope is the first section of the union.
        with_clause = scope.expression.args.get("with")
        is_recursive = with_clause and with_clause.recursive
        if is_recursive and isinstance(cte.this, exp.SetOperation):
            cte_sources[name] = scope.branch(cte.this.this, scope_type=ScopeType.CTE)

        cte_scope = scope.branch(
            cte.this,
            cte_sources=cte_sources,
            outer_columns=cte.alias_column_names,
            scope_type=ScopeType.CTE,
        )

        final_scope = None
        for final_scope in _traverse_scope(cte_scope):
            yield final_scope

        # Only register the CTE when its traversal produced at least one scope.
        if final_scope:
            cte_sources[name] = final_scope
            scope.cte_scopes.append(final_scope)

    scope.sources.update(cte_sources)
    scope.cte_sources.update(cte_sources)
674
675
def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    Tell whether `expression` is a Subquery that introduces a new scope.

    (tbl1 JOIN tbl2) is also represented as a Subquery, but it is not a real "derived
    table" because it does not open a new scope. An aliased Subquery is the exception:
    the alias shadows every name underneath it.
    """
    if not isinstance(expression, exp.Subquery):
        return False

    wraps_query = isinstance(expression.this, exp.UNWRAPPED_QUERIES)
    return bool(expression.alias or wraps_query)
685
686
def _is_from_or_join(expression: exp.Expression) -> bool:
    """
    Determine if `expression` is the FROM or JOIN clause of a SELECT statement.

    Any number of wrapping Subquery nodes is skipped before the check.
    """
    ancestor = expression.parent

    while True:
        if not isinstance(ancestor, exp.Subquery):
            return isinstance(ancestor, (exp.From, exp.Join))
        # Subqueries can be arbitrarily nested
        ancestor = ancestor.parent
698
699
def _traverse_tables(scope):
    """
    Yield the scopes of derived tables and UDTFs in FROM/JOIN/LATERAL, and register sources.

    Plain tables are only recorded as sources. Derived tables and UDTFs are traversed as
    child scopes; the last scope yielded for each one becomes its source and is added to
    `scope.derived_table_scopes` or `scope.udtf_scopes`, and to `scope.table_scopes`.
    An expression whose traversal yields no scope is skipped instead of being registered.
    """
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # `expressions` grows while it is iterated: nested joins and table constructs are
    # appended below and picked up by this same loop.
    for expression in expressions:
        if isinstance(expression, exp.Final):
            expression = expression.this
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        # Reset per expression so that a traversal yielding nothing neither raises
        # UnboundLocalError nor re-registers the scope of a previous expression.
        child_scope = None
        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_columns=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

        if child_scope is None:
            continue

        # Tables without aliases will be set as ""
        # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
        # Until then, this means that only a single, unaliased derived table is allowed (rather,
        # the latest one wins.
        sources[expression.alias] = child_scope

        # Register the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)
782
783
def _traverse_subqueries(scope):
    """
    Yield the scopes of every subquery in `scope`.

    The last scope yielded for each subquery is appended to `scope.subquery_scopes`.
    """
    for subquery in scope.subqueries:
        branched = scope.branch(subquery, scope_type=ScopeType.SUBQUERY)

        last_scope = None
        for last_scope in _traverse_scope(branched):
            yield last_scope

        scope.subquery_scopes.append(last_scope)
791
792
def _traverse_udtfs(scope):
    """
    Yield the scopes of derived tables nested inside an UNNEST or LATERAL.

    For an Unnest its `expressions` are inspected, for a Lateral its `this`; any other
    expression has nothing to traverse. The last scope yielded for each derived table is
    stored as a source under its alias and appended to `scope.subquery_scopes`. A derived
    table whose traversal yields no scope is skipped, so `None` never ends up in
    `scope.subquery_scopes`.
    """
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if not _is_derived_table(expression):
            continue

        top = None
        for top in _traverse_scope(
            scope.branch(
                expression,
                scope_type=ScopeType.SUBQUERY,
                outer_columns=expression.alias_column_names,
            )
        ):
            yield top

        # Nothing was yielded (untraversable expression): register neither source nor scope.
        if top is None:
            continue

        sources[expression.alias] = top
        scope.subquery_scopes.append(top)

    scope.sources.update(sources)
819
820
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    A node that starts a child scope is still yielded itself; only its subtree is skipped.
    The "joins", "laterals" and "pivots" args of such a Subquery/UDTF belong to the
    current scope, so they are walked as well.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: node
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        # Forward `prune` so the caller's filter also applies to these subtrees.
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
864
865
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the requested types once instead of once per visited node.
    types = tuple(ensure_collection(expression_types))

    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node
884
885
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None when no node in the scope matches.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
logger = <Logger sqlglot (WARNING)>
TRAVERSABLES = (<class 'sqlglot.expressions.Query'>, <class 'sqlglot.expressions.DDL'>, <class 'sqlglot.expressions.DML'>)
class ScopeType(enum.Enum):
class ScopeType(Enum):
    """Kind of a `Scope`, stored in `Scope.scope_type` and tested by the `Scope.is_*` properties."""

    # Default `scope_type` of a newly constructed `Scope`; checked by `Scope.is_root`.
    ROOT = auto()
    # Checked by `Scope.is_subquery`; branching with this type marks the new scope as
    # one that can be correlated.
    SUBQUERY = auto()
    # Checked by `Scope.is_derived_table`.
    DERIVED_TABLE = auto()
    # Common table expression; checked by `Scope.is_cte`.
    CTE = auto()
    # Checked by `Scope.is_union`.
    UNION = auto()
    # User Defined Table Function; checked by `Scope.is_udtf`. Branching with this type
    # also marks the new scope as one that can be correlated.
    UDTF = auto()

An enumeration.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
class Scope:
 28class Scope:
 29    """
 30    Selection scope.
 31
 32    Attributes:
 33        expression (exp.Select|exp.SetOperation): Root expression of this scope
 34        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 35            a Table expression or another Scope instance. For example:
 36                SELECT * FROM x                     {"x": Table(this="x")}
 37                SELECT * FROM x AS y                {"y": Table(this="x")}
 38                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 39        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 40            For example:
 41                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 42            The LATERAL VIEW EXPLODE gets x as a source.
 43        cte_sources (dict[str, Scope]): Sources from CTES
 44        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
 45            defines a column list for the alias of this scope, this is that list of columns.
 46            For example:
 47                SELECT * FROM (SELECT ...) AS y(col1, col2)
 48            The inner query would have `["col1", "col2"]` for its `outer_columns`
 49        parent (Scope): Parent scope
 50        scope_type (ScopeType): Type of this scope, relative to it's parent
 51        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 52        cte_scopes (list[Scope]): List of all child scopes for CTEs
 53        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 54        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 55        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 56        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 57            a list of the left and right child scopes.
 58    """
 59
 60    def __init__(
 61        self,
 62        expression,
 63        sources=None,
 64        outer_columns=None,
 65        parent=None,
 66        scope_type=ScopeType.ROOT,
 67        lateral_sources=None,
 68        cte_sources=None,
 69        can_be_correlated=None,
 70    ):
 71        self.expression = expression
 72        self.sources = sources or {}
 73        self.lateral_sources = lateral_sources or {}
 74        self.cte_sources = cte_sources or {}
 75        self.sources.update(self.lateral_sources)
 76        self.sources.update(self.cte_sources)
 77        self.outer_columns = outer_columns or []
 78        self.parent = parent
 79        self.scope_type = scope_type
 80        self.subquery_scopes = []
 81        self.derived_table_scopes = []
 82        self.table_scopes = []
 83        self.cte_scopes = []
 84        self.union_scopes = []
 85        self.udtf_scopes = []
 86        self.can_be_correlated = can_be_correlated
 87        self.clear_cache()
 88
 89    def clear_cache(self):
 90        self._collected = False
 91        self._raw_columns = None
 92        self._stars = None
 93        self._derived_tables = None
 94        self._udtfs = None
 95        self._tables = None
 96        self._ctes = None
 97        self._subqueries = None
 98        self._selected_sources = None
 99        self._columns = None
100        self._external_columns = None
101        self._join_hints = None
102        self._pivots = None
103        self._references = None
104        self._semi_anti_join_tables = None
105
106    def branch(
107        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
108    ):
109        """Branch from the current scope to a new, inner scope"""
110        return Scope(
111            expression=expression.unnest(),
112            sources=sources.copy() if sources else None,
113            parent=self,
114            scope_type=scope_type,
115            cte_sources={**self.cte_sources, **(cte_sources or {})},
116            lateral_sources=lateral_sources.copy() if lateral_sources else None,
117            can_be_correlated=self.can_be_correlated
118            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
119            **kwargs,
120        )
121
122    def _collect(self):
123        self._tables = []
124        self._ctes = []
125        self._subqueries = []
126        self._derived_tables = []
127        self._udtfs = []
128        self._raw_columns = []
129        self._stars = []
130        self._join_hints = []
131        self._semi_anti_join_tables = set()
132
133        for node in self.walk(bfs=False):
134            if node is self.expression:
135                continue
136
137            if isinstance(node, exp.Dot) and node.is_star:
138                self._stars.append(node)
139            elif isinstance(node, exp.Column):
140                if isinstance(node.this, exp.Star):
141                    self._stars.append(node)
142                else:
143                    self._raw_columns.append(node)
144            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
145                parent = node.parent
146                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
147                    self._semi_anti_join_tables.add(node.alias_or_name)
148
149                self._tables.append(node)
150            elif isinstance(node, exp.JoinHint):
151                self._join_hints.append(node)
152            elif isinstance(node, exp.UDTF):
153                self._udtfs.append(node)
154            elif isinstance(node, exp.CTE):
155                self._ctes.append(node)
156            elif _is_derived_table(node) and _is_from_or_join(node):
157                self._derived_tables.append(node)
158            elif isinstance(node, exp.UNWRAPPED_QUERIES):
159                self._subqueries.append(node)
160
161        self._collected = True
162
163    def _ensure_collected(self):
164        if not self._collected:
165            self._collect()
166
167    def walk(self, bfs=True, prune=None):
168        return walk_in_scope(self.expression, bfs=bfs, prune=None)
169
170    def find(self, *expression_types, bfs=True):
171        return find_in_scope(self.expression, expression_types, bfs=bfs)
172
173    def find_all(self, *expression_types, bfs=True):
174        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
175
176    def replace(self, old, new):
177        """
178        Replace `old` with `new`.
179
180        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
181
182        Args:
183            old (exp.Expression): old node
184            new (exp.Expression): new node
185        """
186        old.replace(new)
187        self.clear_cache()
188
189    @property
190    def tables(self):
191        """
192        List of tables in this scope.
193
194        Returns:
195            list[exp.Table]: tables
196        """
197        self._ensure_collected()
198        return self._tables
199
200    @property
201    def ctes(self):
202        """
203        List of CTEs in this scope.
204
205        Returns:
206            list[exp.CTE]: ctes
207        """
208        self._ensure_collected()
209        return self._ctes
210
211    @property
212    def derived_tables(self):
213        """
214        List of derived tables in this scope.
215
216        For example:
217            SELECT * FROM (SELECT ...) <- that's a derived table
218
219        Returns:
220            list[exp.Subquery]: derived tables
221        """
222        self._ensure_collected()
223        return self._derived_tables
224
225    @property
226    def udtfs(self):
227        """
228        List of "User Defined Tabular Functions" in this scope.
229
230        Returns:
231            list[exp.UDTF]: UDTFs
232        """
233        self._ensure_collected()
234        return self._udtfs
235
236    @property
237    def subqueries(self):
238        """
239        List of subqueries in this scope.
240
241        For example:
242            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
243
244        Returns:
245            list[exp.Select | exp.SetOperation]: subqueries
246        """
247        self._ensure_collected()
248        return self._subqueries
249
250    @property
251    def stars(self) -> t.List[exp.Column | exp.Dot]:
252        """
253        List of star expressions (columns or dots) in this scope.
254        """
255        self._ensure_collected()
256        return self._stars
257
258    @property
259    def columns(self):
260        """
261        List of columns in this scope.
262
263        Returns:
264            list[exp.Column]: Column instances in this scope, plus any
265                Columns that reference this scope from correlated subqueries.
266        """
267        if self._columns is None:
268            self._ensure_collected()
269            columns = self._raw_columns
270
271            external_columns = [
272                column
273                for scope in itertools.chain(
274                    self.subquery_scopes,
275                    self.udtf_scopes,
276                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
277                )
278                for column in scope.external_columns
279            ]
280
281            named_selects = set(self.expression.named_selects)
282
283            self._columns = []
284            for column in columns + external_columns:
285                ancestor = column.find_ancestor(
286                    exp.Select,
287                    exp.Qualify,
288                    exp.Order,
289                    exp.Having,
290                    exp.Hint,
291                    exp.Table,
292                    exp.Star,
293                    exp.Distinct,
294                )
295                if (
296                    not ancestor
297                    or column.table
298                    or isinstance(ancestor, exp.Select)
299                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
300                    or (
301                        isinstance(ancestor, (exp.Order, exp.Distinct))
302                        and (
303                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
304                            or column.name not in named_selects
305                        )
306                    )
307                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
308                ):
309                    self._columns.append(column)
310
311        return self._columns
312
313    @property
314    def selected_sources(self):
315        """
316        Mapping of nodes and sources that are actually selected from in this scope.
317
318        That is, all tables in a schema are selectable at any point. But a
319        table only becomes a selected source if it's included in a FROM or JOIN clause.
320
321        Returns:
322            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
323        """
324        if self._selected_sources is None:
325            result = {}
326
327            for name, node in self.references:
328                if name in self._semi_anti_join_tables:
329                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
330                    # selected source
331                    continue
332
333                if name in result:
334                    raise OptimizeError(f"Alias already used: {name}")
335                if name in self.sources:
336                    result[name] = (node, self.sources[name])
337
338            self._selected_sources = result
339        return self._selected_sources
340
341    @property
342    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
343        if self._references is None:
344            self._references = []
345
346            for table in self.tables:
347                self._references.append((table.alias_or_name, table))
348            for expression in itertools.chain(self.derived_tables, self.udtfs):
349                self._references.append(
350                    (
351                        expression.alias,
352                        expression if expression.args.get("pivots") else expression.unnest(),
353                    )
354                )
355
356        return self._references
357
358    @property
359    def external_columns(self):
360        """
361        Columns that appear to reference sources in outer scopes.
362
363        Returns:
364            list[exp.Column]: Column instances that don't reference
365                sources in the current scope.
366        """
367        if self._external_columns is None:
368            if isinstance(self.expression, exp.SetOperation):
369                left, right = self.union_scopes
370                self._external_columns = left.external_columns + right.external_columns
371            else:
372                self._external_columns = [
373                    c
374                    for c in self.columns
375                    if c.table not in self.selected_sources
376                    and c.table not in self.semi_or_anti_join_tables
377                ]
378
379        return self._external_columns
380
381    @property
382    def unqualified_columns(self):
383        """
384        Unqualified columns in the current scope.
385
386        Returns:
387             list[exp.Column]: Unqualified columns
388        """
389        return [c for c in self.columns if not c.table]
390
391    @property
392    def join_hints(self):
393        """
394        Hints that exist in the scope that reference tables
395
396        Returns:
397            list[exp.JoinHint]: Join hints that are referenced within the scope
398        """
399        if self._join_hints is None:
400            return []
401        return self._join_hints
402
403    @property
404    def pivots(self):
405        if not self._pivots:
406            self._pivots = [
407                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
408            ]
409
410        return self._pivots
411
412    @property
413    def semi_or_anti_join_tables(self):
414        return self._semi_anti_join_tables or set()
415
416    def source_columns(self, source_name):
417        """
418        Get all columns in the current scope for a particular source.
419
420        Args:
421            source_name (str): Name of the source
422        Returns:
423            list[exp.Column]: Column instances that reference `source_name`
424        """
425        return [column for column in self.columns if column.table == source_name]
426
427    @property
428    def is_subquery(self):
429        """Determine if this scope is a subquery"""
430        return self.scope_type == ScopeType.SUBQUERY
431
432    @property
433    def is_derived_table(self):
434        """Determine if this scope is a derived table"""
435        return self.scope_type == ScopeType.DERIVED_TABLE
436
437    @property
438    def is_union(self):
439        """Determine if this scope is a union"""
440        return self.scope_type == ScopeType.UNION
441
442    @property
443    def is_cte(self):
444        """Determine if this scope is a common table expression"""
445        return self.scope_type == ScopeType.CTE
446
447    @property
448    def is_root(self):
449        """Determine if this is the root scope"""
450        return self.scope_type == ScopeType.ROOT
451
452    @property
453    def is_udtf(self):
454        """Determine if this scope is a UDTF (User Defined Table Function)"""
455        return self.scope_type == ScopeType.UDTF
456
457    @property
458    def is_correlated_subquery(self):
459        """Determine if this scope is a correlated subquery"""
460        return bool(self.can_be_correlated and self.external_columns)
461
462    def rename_source(self, old_name, new_name):
463        """Rename a source in this scope"""
464        columns = self.sources.pop(old_name or "", [])
465        self.sources[new_name] = columns
466
467    def add_source(self, name, source):
468        """Add a source to this scope"""
469        self.sources[name] = source
470        self.clear_cache()
471
472    def remove_source(self, name):
473        """Remove a source from this scope"""
474        self.sources.pop(name, None)
475        self.clear_cache()
476
477    def __repr__(self):
478        return f"Scope<{self.expression.sql()}>"
479
480    def traverse(self):
481        """
482        Traverse the scope tree from this node.
483
484        Yields:
485            Scope: scope instances in depth-first-search post-order
486        """
487        stack = [self]
488        result = []
489        while stack:
490            scope = stack.pop()
491            result.append(scope)
492            stack.extend(
493                itertools.chain(
494                    scope.cte_scopes,
495                    scope.union_scopes,
496                    scope.table_scopes,
497                    scope.subquery_scopes,
498                )
499            )
500
501        yield from reversed(result)
502
503    def ref_count(self):
504        """
505        Count the number of times each scope in this tree is referenced.
506
507        Returns:
508            dict[int, int]: Mapping of Scope instance ID to reference count
509        """
510        scope_ref_count = defaultdict(lambda: 0)
511
512        for scope in self.traverse():
513            for _, source in scope.selected_sources.values():
514                scope_ref_count[id(source)] += 1
515
516        return scope_ref_count

Selection scope.

Attributes:
  • expression (exp.Select|exp.SetOperation): Root expression of this scope
  • sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: `SELECT * FROM x` gives {"x": Table(this="x")}; `SELECT * FROM x AS y` gives {"y": Table(this="x")}; `SELECT * FROM (SELECT ...) AS y` gives {"y": Scope(...)}.
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;` the LATERAL VIEW EXPLODE gets x as a source.
  • cte_sources (dict[str, Scope]): Sources from CTEs
  • outer_columns (list[str]): If this is a derived table or CTE, and the outer query defines a column list for the alias of this scope, this is that list of columns. For example, in `SELECT * FROM (SELECT ...) AS y(col1, col2)` the inner query would have ["col1", "col2"] for its outer_columns.
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to its parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
  • udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
Scope( expression, sources=None, outer_columns=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None, cte_sources=None, can_be_correlated=None)
60    def __init__(
61        self,
62        expression,
63        sources=None,
64        outer_columns=None,
65        parent=None,
66        scope_type=ScopeType.ROOT,
67        lateral_sources=None,
68        cte_sources=None,
69        can_be_correlated=None,
70    ):
71        self.expression = expression
72        self.sources = sources or {}
73        self.lateral_sources = lateral_sources or {}
74        self.cte_sources = cte_sources or {}
75        self.sources.update(self.lateral_sources)
76        self.sources.update(self.cte_sources)
77        self.outer_columns = outer_columns or []
78        self.parent = parent
79        self.scope_type = scope_type
80        self.subquery_scopes = []
81        self.derived_table_scopes = []
82        self.table_scopes = []
83        self.cte_scopes = []
84        self.union_scopes = []
85        self.udtf_scopes = []
86        self.can_be_correlated = can_be_correlated
87        self.clear_cache()
expression
sources
lateral_sources
cte_sources
outer_columns
parent
scope_type
subquery_scopes
derived_table_scopes
table_scopes
cte_scopes
union_scopes
udtf_scopes
can_be_correlated
def clear_cache(self):
 89    def clear_cache(self):
 90        self._collected = False
 91        self._raw_columns = None
 92        self._stars = None
 93        self._derived_tables = None
 94        self._udtfs = None
 95        self._tables = None
 96        self._ctes = None
 97        self._subqueries = None
 98        self._selected_sources = None
 99        self._columns = None
100        self._external_columns = None
101        self._join_hints = None
102        self._pivots = None
103        self._references = None
104        self._semi_anti_join_tables = None
def branch( self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs):
106    def branch(
107        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
108    ):
109        """Branch from the current scope to a new, inner scope"""
110        return Scope(
111            expression=expression.unnest(),
112            sources=sources.copy() if sources else None,
113            parent=self,
114            scope_type=scope_type,
115            cte_sources={**self.cte_sources, **(cte_sources or {})},
116            lateral_sources=lateral_sources.copy() if lateral_sources else None,
117            can_be_correlated=self.can_be_correlated
118            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
119            **kwargs,
120        )

Branch from the current scope to a new, inner scope

def walk(self, bfs=True, prune=None):
167    def walk(self, bfs=True, prune=None):
168        return walk_in_scope(self.expression, bfs=bfs, prune=None)
def find(self, *expression_types, bfs=True):
170    def find(self, *expression_types, bfs=True):
171        return find_in_scope(self.expression, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
173    def find_all(self, *expression_types, bfs=True):
174        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
def replace(self, old, new):
176    def replace(self, old, new):
177        """
178        Replace `old` with `new`.
179
180        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
181
182        Args:
183            old (exp.Expression): old node
184            new (exp.Expression): new node
185        """
186        old.replace(new)
187        self.clear_cache()

Replace old with new.

This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expression): old node
  • new (exp.Expression): new node
tables
189    @property
190    def tables(self):
191        """
192        List of tables in this scope.
193
194        Returns:
195            list[exp.Table]: tables
196        """
197        self._ensure_collected()
198        return self._tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes
200    @property
201    def ctes(self):
202        """
203        List of CTEs in this scope.
204
205        Returns:
206            list[exp.CTE]: ctes
207        """
208        self._ensure_collected()
209        return self._ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables
211    @property
212    def derived_tables(self):
213        """
214        List of derived tables in this scope.
215
216        For example:
217            SELECT * FROM (SELECT ...) <- that's a derived table
218
219        Returns:
220            list[exp.Subquery]: derived tables
221        """
222        self._ensure_collected()
223        return self._derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs
225    @property
226    def udtfs(self):
227        """
228        List of "User Defined Tabular Functions" in this scope.
229
230        Returns:
231            list[exp.UDTF]: UDTFs
232        """
233        self._ensure_collected()
234        return self._udtfs

List of "User Defined Tabular Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

subqueries
236    @property
237    def subqueries(self):
238        """
239        List of subqueries in this scope.
240
241        For example:
242            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
243
244        Returns:
245            list[exp.Select | exp.SetOperation]: subqueries
246        """
247        self._ensure_collected()
248        return self._subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Select | exp.SetOperation]: subqueries

250    @property
251    def stars(self) -> t.List[exp.Column | exp.Dot]:
252        """
253        List of star expressions (columns or dots) in this scope.
254        """
255        self._ensure_collected()
256        return self._stars

List of star expressions (columns or dots) in this scope.

columns
258    @property
259    def columns(self):
260        """
261        List of columns in this scope.
262
263        Returns:
264            list[exp.Column]: Column instances in this scope, plus any
265                Columns that reference this scope from correlated subqueries.
266        """
267        if self._columns is None:
268            self._ensure_collected()
269            columns = self._raw_columns
270
271            external_columns = [
272                column
273                for scope in itertools.chain(
274                    self.subquery_scopes,
275                    self.udtf_scopes,
276                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
277                )
278                for column in scope.external_columns
279            ]
280
281            named_selects = set(self.expression.named_selects)
282
283            self._columns = []
284            for column in columns + external_columns:
285                ancestor = column.find_ancestor(
286                    exp.Select,
287                    exp.Qualify,
288                    exp.Order,
289                    exp.Having,
290                    exp.Hint,
291                    exp.Table,
292                    exp.Star,
293                    exp.Distinct,
294                )
295                if (
296                    not ancestor
297                    or column.table
298                    or isinstance(ancestor, exp.Select)
299                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
300                    or (
301                        isinstance(ancestor, (exp.Order, exp.Distinct))
302                        and (
303                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
304                            or column.name not in named_selects
305                        )
306                    )
307                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
308                ):
309                    self._columns.append(column)
310
311        return self._columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

selected_sources
313    @property
314    def selected_sources(self):
315        """
316        Mapping of nodes and sources that are actually selected from in this scope.
317
318        That is, all tables in a schema are selectable at any point. But a
319        table only becomes a selected source if it's included in a FROM or JOIN clause.
320
321        Returns:
322            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
323        """
324        if self._selected_sources is None:
325            result = {}
326
327            for name, node in self.references:
328                if name in self._semi_anti_join_tables:
329                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
330                    # selected source
331                    continue
332
333                if name in result:
334                    raise OptimizeError(f"Alias already used: {name}")
335                if name in self.sources:
336                    result[name] = (node, self.sources[name])
337
338            self._selected_sources = result
339        return self._selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

references: List[Tuple[str, sqlglot.expressions.Expression]]
341    @property
342    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
343        if self._references is None:
344            self._references = []
345
346            for table in self.tables:
347                self._references.append((table.alias_or_name, table))
348            for expression in itertools.chain(self.derived_tables, self.udtfs):
349                self._references.append(
350                    (
351                        expression.alias,
352                        expression if expression.args.get("pivots") else expression.unnest(),
353                    )
354                )
355
356        return self._references
external_columns
358    @property
359    def external_columns(self):
360        """
361        Columns that appear to reference sources in outer scopes.
362
363        Returns:
364            list[exp.Column]: Column instances that don't reference
365                sources in the current scope.
366        """
367        if self._external_columns is None:
368            if isinstance(self.expression, exp.SetOperation):
369                left, right = self.union_scopes
370                self._external_columns = left.external_columns + right.external_columns
371            else:
372                self._external_columns = [
373                    c
374                    for c in self.columns
375                    if c.table not in self.selected_sources
376                    and c.table not in self.semi_or_anti_join_tables
377                ]
378
379        return self._external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

unqualified_columns
381    @property
382    def unqualified_columns(self):
383        """
384        Unqualified columns in the current scope.
385
386        Returns:
387             list[exp.Column]: Unqualified columns
388        """
389        return [c for c in self.columns if not c.table]

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints
391    @property
392    def join_hints(self):
393        """
394        Hints that exist in the scope that reference tables
395
396        Returns:
397            list[exp.JoinHint]: Join hints that are referenced within the scope
398        """
399        if self._join_hints is None:
400            return []
401        return self._join_hints

Hints that exist in the scope that reference tables

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

pivots
403    @property
404    def pivots(self):
405        if not self._pivots:
406            self._pivots = [
407                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
408            ]
409
410        return self._pivots
semi_or_anti_join_tables
412    @property
413    def semi_or_anti_join_tables(self):
414        return self._semi_anti_join_tables or set()
def source_columns(self, source_name):
416    def source_columns(self, source_name):
417        """
418        Get all columns in the current scope for a particular source.
419
420        Args:
421            source_name (str): Name of the source
422        Returns:
423            list[exp.Column]: Column instances that reference `source_name`
424        """
425        return [column for column in self.columns if column.table == source_name]

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery
427    @property
428    def is_subquery(self):
429        """Determine if this scope is a subquery"""
430        return self.scope_type == ScopeType.SUBQUERY

Determine if this scope is a subquery

is_derived_table
432    @property
433    def is_derived_table(self):
434        """Determine if this scope is a derived table"""
435        return self.scope_type == ScopeType.DERIVED_TABLE

Determine if this scope is a derived table

is_union
437    @property
438    def is_union(self):
439        """Determine if this scope is a union"""
440        return self.scope_type == ScopeType.UNION

Determine if this scope is a union

is_cte
442    @property
443    def is_cte(self):
444        """Determine if this scope is a common table expression"""
445        return self.scope_type == ScopeType.CTE

Determine if this scope is a common table expression

is_root
447    @property
448    def is_root(self):
449        """Determine if this is the root scope"""
450        return self.scope_type == ScopeType.ROOT

Determine if this is the root scope

is_udtf
452    @property
453    def is_udtf(self):
454        """Determine if this scope is a UDTF (User Defined Table Function)"""
455        return self.scope_type == ScopeType.UDTF

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery
457    @property
458    def is_correlated_subquery(self):
459        """Determine if this scope is a correlated subquery"""
460        return bool(self.can_be_correlated and self.external_columns)

Determine if this scope is a correlated subquery

def rename_source(self, old_name, new_name):
462    def rename_source(self, old_name, new_name):
463        """Rename a source in this scope"""
464        columns = self.sources.pop(old_name or "", [])
465        self.sources[new_name] = columns

Rename a source in this scope

def add_source(self, name, source):
467    def add_source(self, name, source):
468        """Add a source to this scope"""
469        self.sources[name] = source
470        self.clear_cache()

Add a source to this scope

def remove_source(self, name):
472    def remove_source(self, name):
473        """Remove a source from this scope"""
474        self.sources.pop(name, None)
475        self.clear_cache()

Remove a source from this scope

def traverse(self):
480    def traverse(self):
481        """
482        Traverse the scope tree from this node.
483
484        Yields:
485            Scope: scope instances in depth-first-search post-order
486        """
487        stack = [self]
488        result = []
489        while stack:
490            scope = stack.pop()
491            result.append(scope)
492            stack.extend(
493                itertools.chain(
494                    scope.cte_scopes,
495                    scope.union_scopes,
496                    scope.table_scopes,
497                    scope.subquery_scopes,
498                )
499            )
500
501        yield from reversed(result)

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self):
503    def ref_count(self):
504        """
505        Count the number of times each scope in this tree is referenced.
506
507        Returns:
508            dict[int, int]: Mapping of Scope instance ID to reference count
509        """
510        scope_ref_count = defaultdict(lambda: 0)
511
512        for scope in self.traverse():
513            for _, source in scope.selected_sources.values():
514                scope_ref_count[id(source)] += 1
515
516        return scope_ref_count

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope( expression: sqlglot.expressions.Expression) -> List[Scope]:
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Collect the "scopes" of an expression.

    A "Scope" represents the current context of a Select statement.

    Scopes carry more information than the expression tree alone, such as the
    source names visible inside a subquery, which is useful when optimizing
    queries. The result is a list rather than a generator, because a generator
    could expose scopes whose properties are still incomplete.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty when `expression` is not
        an instance of one of the `TRAVERSABLES` types
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression: Expression to traverse
Returns:

A list of the created scope instances

def build_scope( expression: sqlglot.expressions.Expression) -> Optional[Scope]:
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree and hand back its root.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, taken as the last element of `traverse_scope(expression)`
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)

Build a scope tree.

Arguments:
  • expression: Expression to build the scope tree for.
Returns:

The root scope

def walk_in_scope(expression, bfs=True, prune=None):
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: the visited nodes
    """
    # We'll use this variable to pass state into the walk generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        # Reset the flag for every node; it is raised again below when needed.
        crossed_scope_boundary = False

        yield node

        # The root of the walk is never treated as a scope boundary.
        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                # NOTE(review): `prune` is not forwarded to this nested walk - confirm
                # whether that is intentional.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs)

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.

Arguments:
  • expression (exp.Expression):
  • bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
  • prune ((node) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:

exp.Expression: the visited nodes

def find_all_in_scope(expression, expression_types, bfs=True):
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the requested types once instead of on every visited node.
    types = tuple(ensure_collection(expression_types))

    # A separate loop variable keeps the `expression` argument from being rebound.
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:

exp.Expression: nodes

def find_in_scope(expression, expression_types, bfs=True):
def find_in_scope(expression, expression_types, bfs=True):
    """
    Look up the first node in this scope that matches one of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None if no node matching
        the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:

exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.