Edit on GitHub

sqlglot.optimizer.scope

  1from __future__ import annotations
  2
  3import itertools
  4import logging
  5import typing as t
  6from collections import defaultdict
  7from enum import Enum, auto
  8
  9from sqlglot import exp
 10from sqlglot.errors import OptimizeError
 11from sqlglot.helper import ensure_collection, find_new_name, seq_get
 12
 13logger = logging.getLogger("sqlglot")
 14
 15TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)
 16
 17
 18class ScopeType(Enum):
 19    ROOT = auto()
 20    SUBQUERY = auto()
 21    DERIVED_TABLE = auto()
 22    CTE = auto()
 23    UNION = auto()
 24    UDTF = auto()
 25
 26
 27class Scope:
 28    """
 29    Selection scope.
 30
 31    Attributes:
 32        expression (exp.Select|exp.SetOperation): Root expression of this scope
 33        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 34            a Table expression or another Scope instance. For example:
 35                SELECT * FROM x                     {"x": Table(this="x")}
 36                SELECT * FROM x AS y                {"y": Table(this="x")}
 37                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 38        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 39            For example:
 40                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 41            The LATERAL VIEW EXPLODE gets x as a source.
 42        cte_sources (dict[str, Scope]): Sources from CTES
 43        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
 44            defines a column list for the alias of this scope, this is that list of columns.
 45            For example:
 46                SELECT * FROM (SELECT ...) AS y(col1, col2)
 47            The inner query would have `["col1", "col2"]` for its `outer_columns`
 48        parent (Scope): Parent scope
 49        scope_type (ScopeType): Type of this scope, relative to it's parent
 50        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 51        cte_scopes (list[Scope]): List of all child scopes for CTEs
 52        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 53        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 54        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 55        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 56            a list of the left and right child scopes.
 57    """
 58
 59    def __init__(
 60        self,
 61        expression,
 62        sources=None,
 63        outer_columns=None,
 64        parent=None,
 65        scope_type=ScopeType.ROOT,
 66        lateral_sources=None,
 67        cte_sources=None,
 68        can_be_correlated=None,
 69    ):
 70        self.expression = expression
 71        self.sources = sources or {}
 72        self.lateral_sources = lateral_sources or {}
 73        self.cte_sources = cte_sources or {}
 74        self.sources.update(self.lateral_sources)
 75        self.sources.update(self.cte_sources)
 76        self.outer_columns = outer_columns or []
 77        self.parent = parent
 78        self.scope_type = scope_type
 79        self.subquery_scopes = []
 80        self.derived_table_scopes = []
 81        self.table_scopes = []
 82        self.cte_scopes = []
 83        self.union_scopes = []
 84        self.udtf_scopes = []
 85        self.can_be_correlated = can_be_correlated
 86        self.clear_cache()
 87
 88    def clear_cache(self):
 89        self._collected = False
 90        self._raw_columns = None
 91        self._stars = None
 92        self._derived_tables = None
 93        self._udtfs = None
 94        self._tables = None
 95        self._ctes = None
 96        self._subqueries = None
 97        self._selected_sources = None
 98        self._columns = None
 99        self._external_columns = None
100        self._join_hints = None
101        self._pivots = None
102        self._references = None
103        self._semi_anti_join_tables = None
104
105    def branch(
106        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
107    ):
108        """Branch from the current scope to a new, inner scope"""
109        return Scope(
110            expression=expression.unnest(),
111            sources=sources.copy() if sources else None,
112            parent=self,
113            scope_type=scope_type,
114            cte_sources={**self.cte_sources, **(cte_sources or {})},
115            lateral_sources=lateral_sources.copy() if lateral_sources else None,
116            can_be_correlated=self.can_be_correlated
117            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
118            **kwargs,
119        )
120
121    def _collect(self):
122        self._tables = []
123        self._ctes = []
124        self._subqueries = []
125        self._derived_tables = []
126        self._udtfs = []
127        self._raw_columns = []
128        self._stars = []
129        self._join_hints = []
130        self._semi_anti_join_tables = set()
131
132        for node in self.walk(bfs=False):
133            if node is self.expression:
134                continue
135
136            if isinstance(node, exp.Dot) and node.is_star:
137                self._stars.append(node)
138            elif isinstance(node, exp.Column):
139                if isinstance(node.this, exp.Star):
140                    self._stars.append(node)
141                else:
142                    self._raw_columns.append(node)
143            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
144                parent = node.parent
145                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
146                    self._semi_anti_join_tables.add(node.alias_or_name)
147
148                self._tables.append(node)
149            elif isinstance(node, exp.JoinHint):
150                self._join_hints.append(node)
151            elif isinstance(node, exp.UDTF):
152                self._udtfs.append(node)
153            elif isinstance(node, exp.CTE):
154                self._ctes.append(node)
155            elif _is_derived_table(node) and _is_from_or_join(node):
156                self._derived_tables.append(node)
157            elif isinstance(node, exp.UNWRAPPED_QUERIES):
158                self._subqueries.append(node)
159
160        self._collected = True
161
162    def _ensure_collected(self):
163        if not self._collected:
164            self._collect()
165
166    def walk(self, bfs=True, prune=None):
167        return walk_in_scope(self.expression, bfs=bfs, prune=None)
168
169    def find(self, *expression_types, bfs=True):
170        return find_in_scope(self.expression, expression_types, bfs=bfs)
171
172    def find_all(self, *expression_types, bfs=True):
173        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
174
175    def replace(self, old, new):
176        """
177        Replace `old` with `new`.
178
179        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
180
181        Args:
182            old (exp.Expression): old node
183            new (exp.Expression): new node
184        """
185        old.replace(new)
186        self.clear_cache()
187
188    @property
189    def tables(self):
190        """
191        List of tables in this scope.
192
193        Returns:
194            list[exp.Table]: tables
195        """
196        self._ensure_collected()
197        return self._tables
198
199    @property
200    def ctes(self):
201        """
202        List of CTEs in this scope.
203
204        Returns:
205            list[exp.CTE]: ctes
206        """
207        self._ensure_collected()
208        return self._ctes
209
210    @property
211    def derived_tables(self):
212        """
213        List of derived tables in this scope.
214
215        For example:
216            SELECT * FROM (SELECT ...) <- that's a derived table
217
218        Returns:
219            list[exp.Subquery]: derived tables
220        """
221        self._ensure_collected()
222        return self._derived_tables
223
224    @property
225    def udtfs(self):
226        """
227        List of "User Defined Tabular Functions" in this scope.
228
229        Returns:
230            list[exp.UDTF]: UDTFs
231        """
232        self._ensure_collected()
233        return self._udtfs
234
235    @property
236    def subqueries(self):
237        """
238        List of subqueries in this scope.
239
240        For example:
241            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
242
243        Returns:
244            list[exp.Select | exp.SetOperation]: subqueries
245        """
246        self._ensure_collected()
247        return self._subqueries
248
249    @property
250    def stars(self) -> t.List[exp.Column | exp.Dot]:
251        """
252        List of star expressions (columns or dots) in this scope.
253        """
254        self._ensure_collected()
255        return self._stars
256
257    @property
258    def columns(self):
259        """
260        List of columns in this scope.
261
262        Returns:
263            list[exp.Column]: Column instances in this scope, plus any
264                Columns that reference this scope from correlated subqueries.
265        """
266        if self._columns is None:
267            self._ensure_collected()
268            columns = self._raw_columns
269
270            external_columns = [
271                column
272                for scope in itertools.chain(
273                    self.subquery_scopes,
274                    self.udtf_scopes,
275                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
276                )
277                for column in scope.external_columns
278            ]
279
280            named_selects = set(self.expression.named_selects)
281
282            self._columns = []
283            for column in columns + external_columns:
284                ancestor = column.find_ancestor(
285                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
286                )
287                if (
288                    not ancestor
289                    or column.table
290                    or isinstance(ancestor, exp.Select)
291                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
292                    or (
293                        isinstance(ancestor, exp.Order)
294                        and (
295                            isinstance(ancestor.parent, exp.Window)
296                            or column.name not in named_selects
297                        )
298                    )
299                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
300                ):
301                    self._columns.append(column)
302
303        return self._columns
304
305    @property
306    def selected_sources(self):
307        """
308        Mapping of nodes and sources that are actually selected from in this scope.
309
310        That is, all tables in a schema are selectable at any point. But a
311        table only becomes a selected source if it's included in a FROM or JOIN clause.
312
313        Returns:
314            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
315        """
316        if self._selected_sources is None:
317            result = {}
318
319            for name, node in self.references:
320                if name in self._semi_anti_join_tables:
321                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
322                    # selected source
323                    continue
324
325                if name in result:
326                    raise OptimizeError(f"Alias already used: {name}")
327                if name in self.sources:
328                    result[name] = (node, self.sources[name])
329
330            self._selected_sources = result
331        return self._selected_sources
332
333    @property
334    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
335        if self._references is None:
336            self._references = []
337
338            for table in self.tables:
339                self._references.append((table.alias_or_name, table))
340            for expression in itertools.chain(self.derived_tables, self.udtfs):
341                self._references.append(
342                    (
343                        expression.alias,
344                        expression if expression.args.get("pivots") else expression.unnest(),
345                    )
346                )
347
348        return self._references
349
350    @property
351    def external_columns(self):
352        """
353        Columns that appear to reference sources in outer scopes.
354
355        Returns:
356            list[exp.Column]: Column instances that don't reference
357                sources in the current scope.
358        """
359        if self._external_columns is None:
360            if isinstance(self.expression, exp.SetOperation):
361                left, right = self.union_scopes
362                self._external_columns = left.external_columns + right.external_columns
363            else:
364                self._external_columns = [
365                    c
366                    for c in self.columns
367                    if c.table not in self.selected_sources
368                    and c.table not in self.semi_or_anti_join_tables
369                ]
370
371        return self._external_columns
372
373    @property
374    def unqualified_columns(self):
375        """
376        Unqualified columns in the current scope.
377
378        Returns:
379             list[exp.Column]: Unqualified columns
380        """
381        return [c for c in self.columns if not c.table]
382
383    @property
384    def join_hints(self):
385        """
386        Hints that exist in the scope that reference tables
387
388        Returns:
389            list[exp.JoinHint]: Join hints that are referenced within the scope
390        """
391        if self._join_hints is None:
392            return []
393        return self._join_hints
394
395    @property
396    def pivots(self):
397        if not self._pivots:
398            self._pivots = [
399                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
400            ]
401
402        return self._pivots
403
404    @property
405    def semi_or_anti_join_tables(self):
406        return self._semi_anti_join_tables or set()
407
408    def source_columns(self, source_name):
409        """
410        Get all columns in the current scope for a particular source.
411
412        Args:
413            source_name (str): Name of the source
414        Returns:
415            list[exp.Column]: Column instances that reference `source_name`
416        """
417        return [column for column in self.columns if column.table == source_name]
418
419    @property
420    def is_subquery(self):
421        """Determine if this scope is a subquery"""
422        return self.scope_type == ScopeType.SUBQUERY
423
424    @property
425    def is_derived_table(self):
426        """Determine if this scope is a derived table"""
427        return self.scope_type == ScopeType.DERIVED_TABLE
428
429    @property
430    def is_union(self):
431        """Determine if this scope is a union"""
432        return self.scope_type == ScopeType.UNION
433
434    @property
435    def is_cte(self):
436        """Determine if this scope is a common table expression"""
437        return self.scope_type == ScopeType.CTE
438
439    @property
440    def is_root(self):
441        """Determine if this is the root scope"""
442        return self.scope_type == ScopeType.ROOT
443
444    @property
445    def is_udtf(self):
446        """Determine if this scope is a UDTF (User Defined Table Function)"""
447        return self.scope_type == ScopeType.UDTF
448
449    @property
450    def is_correlated_subquery(self):
451        """Determine if this scope is a correlated subquery"""
452        return bool(self.can_be_correlated and self.external_columns)
453
454    def rename_source(self, old_name, new_name):
455        """Rename a source in this scope"""
456        columns = self.sources.pop(old_name or "", [])
457        self.sources[new_name] = columns
458
459    def add_source(self, name, source):
460        """Add a source to this scope"""
461        self.sources[name] = source
462        self.clear_cache()
463
464    def remove_source(self, name):
465        """Remove a source from this scope"""
466        self.sources.pop(name, None)
467        self.clear_cache()
468
469    def __repr__(self):
470        return f"Scope<{self.expression.sql()}>"
471
472    def traverse(self):
473        """
474        Traverse the scope tree from this node.
475
476        Yields:
477            Scope: scope instances in depth-first-search post-order
478        """
479        stack = [self]
480        result = []
481        while stack:
482            scope = stack.pop()
483            result.append(scope)
484            stack.extend(
485                itertools.chain(
486                    scope.cte_scopes,
487                    scope.union_scopes,
488                    scope.table_scopes,
489                    scope.subquery_scopes,
490                )
491            )
492
493        yield from reversed(result)
494
495    def ref_count(self):
496        """
497        Count the number of times each scope in this tree is referenced.
498
499        Returns:
500            dict[int, int]: Mapping of Scope instance ID to reference count
501        """
502        scope_ref_count = defaultdict(lambda: 0)
503
504        for scope in self.traverse():
505            for _, source in scope.selected_sources.values():
506                scope_ref_count[id(source)] += 1
507
508        return scope_ref_count
509
510
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances, or an empty list when `expression`
        is not one of the TRAVERSABLES types.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root_scope = Scope(expression)
    return list(_traverse_scope(root_scope))
540
541
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope (the last scope produced by `traverse_scope`), or None when
        no scopes were produced.
    """
    all_scopes = traverse_scope(expression)
    return all_scopes[-1] if all_scopes else None
553
554
def _traverse_scope(scope):
    """
    Yield the child scopes of `scope` in post-order, dispatching on its expression type.

    For SELECT, subquery, table and UDTF expressions, `scope` itself is yielded last.
    For set operations, the union traversal yields `scope` itself. For DDL and DML
    statements, only the scopes of the queries they contain are yielded.
    """
    node = scope.expression

    if isinstance(node, exp.Select):
        yield from _traverse_select(scope)
        yield scope
        return

    if isinstance(node, exp.SetOperation):
        yield from _traverse_ctes(scope)
        yield from _traverse_union(scope)
        return

    if isinstance(node, exp.Subquery):
        traverse_inner = _traverse_select if scope.is_root else _traverse_subqueries
        yield from traverse_inner(scope)
        yield scope
        return

    if isinstance(node, exp.Table):
        yield from _traverse_tables(scope)
        yield scope
        return

    if isinstance(node, exp.UDTF):
        yield from _traverse_udtfs(scope)
        yield scope
        return

    if isinstance(node, exp.DDL):
        body = node.expression
        if isinstance(body, exp.Query):
            yield from _traverse_ctes(scope)
            yield from _traverse_scope(Scope(body, cte_sources=scope.cte_sources))
        return

    if isinstance(node, exp.DML):
        yield from _traverse_ctes(scope)
        for query in find_all_in_scope(node, exp.Query):
            # This check ensures we don't yield the CTE/nested queries twice
            if isinstance(query.parent, (exp.CTE, exp.Subquery)):
                continue
            yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
        return

    logger.warning("Cannot traverse scope %s with type '%s'", node, type(node))
590
591
def _traverse_select(scope):
    """Yield child scopes of `scope`: its CTEs, then its FROM/JOIN sources, then its subqueries."""
    for traverse_part in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traverse_part(scope)
596
597
def _traverse_union(scope):
    """
    Iteratively traverse a (possibly nested) set operation rooted at `scope`.

    Operands are processed left to right. Each non-set-operation operand is traversed
    into its own UNION-typed scope. Whenever a pair of operand scopes is complete, the
    enclosing set-operation scope gets `union_scopes = [left, right]` and is yielded.
    """
    previous = None
    # The scope whose function argument name is reused below holds the most recently yielded scope
    latest = scope
    parent_stack = [scope]
    # Right is pushed before left so that the left operand is popped (processed) first
    pending = [scope.expression.right, scope.expression.left]

    while pending:
        operand = pending.pop()
        parent = parent_stack[-1]

        operand_scope = parent.branch(
            operand,
            outer_columns=parent.outer_columns,
            scope_type=ScopeType.UNION,
        )

        if isinstance(operand, exp.SetOperation):
            # Nested set operation: traverse its CTEs now and queue its operands
            yield from _traverse_ctes(operand_scope)

            parent_stack.append(operand_scope)
            pending.extend([operand.right, operand.left])
            continue

        for latest in _traverse_scope(operand_scope):
            yield latest

        if not previous:
            previous = latest
            continue

        parent_stack.pop()
        parent.union_scopes = [previous, latest]
        previous = parent

        yield parent
631
632
def _traverse_ctes(scope):
    """
    Traverse each CTE of `scope`, yielding the child scopes of every CTE body.

    The root scope of each CTE body is registered under the CTE's name in
    `scope.sources`, `scope.cte_sources` and `scope.cte_scopes`.
    """
    cte_scopes_by_name = {}

    for cte in scope.ctes:
        name = cte.alias

        # A recursive CTE has the form `base_case UNION recursive_case`. The left operand
        # (base case) is registered under the CTE's name before the body is traversed, so
        # that it is passed as a CTE source to the body's scope.
        with_clause = scope.expression.args.get("with")
        if with_clause and with_clause.recursive and isinstance(cte.this, exp.SetOperation):
            cte_scopes_by_name[name] = scope.branch(cte.this.this, scope_type=ScopeType.CTE)

        body_root = None

        body_scope = scope.branch(
            cte.this,
            cte_sources=cte_scopes_by_name,
            outer_columns=cte.alias_column_names,
            scope_type=ScopeType.CTE,
        )
        for body_root in _traverse_scope(body_scope):
            yield body_root

        # The last scope yielded is the root scope of the CTE body
        if body_root:
            cte_scopes_by_name[name] = body_root
            scope.cte_scopes.append(body_root)

    scope.sources.update(cte_scopes_by_name)
    scope.cte_sources.update(cte_scopes_by_name)
667
668
def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule.
    """
    if not isinstance(expression, exp.Subquery):
        return False
    # An alias always makes it a derived table; otherwise it must wrap an actual query
    return bool(expression.alias) or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
678
679
def _is_from_or_join(expression: exp.Expression) -> bool:
    """
    Determine if `expression` is the FROM or JOIN clause of a SELECT statement.

    Any number of enclosing Subquery wrappers between `expression` and the clause is skipped.
    """
    node = expression.parent
    while True:
        if isinstance(node, (exp.From, exp.Join)):
            return True
        if not isinstance(node, exp.Subquery):
            return False
        node = node.parent
691
692
def _traverse_tables(scope):
    """
    Collect the sources of `scope` from its FROM, JOINs and LATERALs (or from the scope's
    own Table expression), traversing any derived tables and UDTFs among them.

    Yields every child scope produced by those traversals. When done, the discovered
    sources are merged into `scope.sources`.
    """
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # NOTE: `expressions` is extended while it is being iterated, so joins nested inside
    # tables and subqueries are visited after the sources that are already queued.
    for expression in expressions:
        if isinstance(expression, exp.Final):
            expression = expression.this
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                # Pick a fresh name instead of overwriting the earlier source
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            # UDTFs get the sources collected so far as lateral sources
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        # Reset per expression so an empty traversal can't raise NameError or
        # re-register the previous expression's scope
        child_scope = None
        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_columns=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins.
            sources[expression.alias] = child_scope

        # Register the final (root) child scope yielded for this expression, if any
        if child_scope is not None:
            scopes.append(child_scope)
            scope.table_scopes.append(child_scope)

    scope.sources.update(sources)
775
776
def _traverse_subqueries(scope):
    """Traverse every subquery of `scope`, recording each subquery's root scope in `subquery_scopes`."""
    for subquery in scope.subqueries:
        root_scope = None
        subquery_scope = scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
        for root_scope in _traverse_scope(subquery_scope):
            yield root_scope
        # The last yielded scope is the root of this subquery's scope tree
        scope.subquery_scopes.append(root_scope)
784
785
def _traverse_udtfs(scope):
    """
    Traverse derived tables that appear as arguments of an UNNEST or LATERAL UDTF scope.

    Each such derived table becomes a SUBQUERY child scope. It is registered in
    `scope.subquery_scopes` and, under its alias, in `scope.sources`.
    """
    udtf = scope.expression
    if isinstance(udtf, exp.Unnest):
        candidates = udtf.expressions
    elif isinstance(udtf, exp.Lateral):
        candidates = [udtf.this]
    else:
        candidates = []

    derived_sources = {}
    for candidate in filter(_is_derived_table, candidates):
        root_scope = None
        candidate_scope = scope.branch(
            candidate,
            scope_type=ScopeType.SUBQUERY,
            outer_columns=candidate.alias_column_names,
        )
        for root_scope in _traverse_scope(candidate_scope):
            yield root_scope
            derived_sources[candidate.alias] = root_scope

        scope.subquery_scopes.append(root_scope)

    scope.sources.update(derived_sources)
812
813
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): the node to start walking from.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if the generator should
            stop traversing this branch of the tree. It also applies to the joins,
            laterals and pivots that are walked separately under a child-scope node.

    Yields:
        exp.Expression: `expression` itself, every node in its scope, and the nodes that
        start child scopes (without their descendants).
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        # Propagate the caller's prune callback into the nested walk
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
857
858
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize once instead of rebuilding the tuple for every visited node
    wanted_types = tuple(ensure_collection(expression_types))
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted_types):
            yield node
877
878
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
logger = <Logger sqlglot (WARNING)>
TRAVERSABLES = (<class 'sqlglot.expressions.Query'>, <class 'sqlglot.expressions.DDL'>, <class 'sqlglot.expressions.DML'>)
class ScopeType(enum.Enum):
19class ScopeType(Enum):
20    ROOT = auto()
21    SUBQUERY = auto()
22    DERIVED_TABLE = auto()
23    CTE = auto()
24    UNION = auto()
25    UDTF = auto()

The kind of a `Scope`, relative to its parent scope: ROOT, SUBQUERY, DERIVED_TABLE, CTE, UNION, or UDTF.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
class Scope:
 28class Scope:
 29    """
 30    Selection scope.
 31
 32    Attributes:
 33        expression (exp.Select|exp.SetOperation): Root expression of this scope
 34        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 35            a Table expression or another Scope instance. For example:
 36                SELECT * FROM x                     {"x": Table(this="x")}
 37                SELECT * FROM x AS y                {"y": Table(this="x")}
 38                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 39        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 40            For example:
 41                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 42            The LATERAL VIEW EXPLODE gets x as a source.
 43        cte_sources (dict[str, Scope]): Sources from CTES
 44        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
 45            defines a column list for the alias of this scope, this is that list of columns.
 46            For example:
 47                SELECT * FROM (SELECT ...) AS y(col1, col2)
 48            The inner query would have `["col1", "col2"]` for its `outer_columns`
 49        parent (Scope): Parent scope
 50        scope_type (ScopeType): Type of this scope, relative to it's parent
 51        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 52        cte_scopes (list[Scope]): List of all child scopes for CTEs
 53        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 54        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 55        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 56        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 57            a list of the left and right child scopes.
 58    """
 59
 60    def __init__(
 61        self,
 62        expression,
 63        sources=None,
 64        outer_columns=None,
 65        parent=None,
 66        scope_type=ScopeType.ROOT,
 67        lateral_sources=None,
 68        cte_sources=None,
 69        can_be_correlated=None,
 70    ):
 71        self.expression = expression
 72        self.sources = sources or {}
 73        self.lateral_sources = lateral_sources or {}
 74        self.cte_sources = cte_sources or {}
 75        self.sources.update(self.lateral_sources)
 76        self.sources.update(self.cte_sources)
 77        self.outer_columns = outer_columns or []
 78        self.parent = parent
 79        self.scope_type = scope_type
 80        self.subquery_scopes = []
 81        self.derived_table_scopes = []
 82        self.table_scopes = []
 83        self.cte_scopes = []
 84        self.union_scopes = []
 85        self.udtf_scopes = []
 86        self.can_be_correlated = can_be_correlated
 87        self.clear_cache()
 88
 89    def clear_cache(self):
 90        self._collected = False
 91        self._raw_columns = None
 92        self._stars = None
 93        self._derived_tables = None
 94        self._udtfs = None
 95        self._tables = None
 96        self._ctes = None
 97        self._subqueries = None
 98        self._selected_sources = None
 99        self._columns = None
100        self._external_columns = None
101        self._join_hints = None
102        self._pivots = None
103        self._references = None
104        self._semi_anti_join_tables = None
105
106    def branch(
107        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
108    ):
109        """Branch from the current scope to a new, inner scope"""
110        return Scope(
111            expression=expression.unnest(),
112            sources=sources.copy() if sources else None,
113            parent=self,
114            scope_type=scope_type,
115            cte_sources={**self.cte_sources, **(cte_sources or {})},
116            lateral_sources=lateral_sources.copy() if lateral_sources else None,
117            can_be_correlated=self.can_be_correlated
118            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
119            **kwargs,
120        )
121
122    def _collect(self):
123        self._tables = []
124        self._ctes = []
125        self._subqueries = []
126        self._derived_tables = []
127        self._udtfs = []
128        self._raw_columns = []
129        self._stars = []
130        self._join_hints = []
131        self._semi_anti_join_tables = set()
132
133        for node in self.walk(bfs=False):
134            if node is self.expression:
135                continue
136
137            if isinstance(node, exp.Dot) and node.is_star:
138                self._stars.append(node)
139            elif isinstance(node, exp.Column):
140                if isinstance(node.this, exp.Star):
141                    self._stars.append(node)
142                else:
143                    self._raw_columns.append(node)
144            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
145                parent = node.parent
146                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
147                    self._semi_anti_join_tables.add(node.alias_or_name)
148
149                self._tables.append(node)
150            elif isinstance(node, exp.JoinHint):
151                self._join_hints.append(node)
152            elif isinstance(node, exp.UDTF):
153                self._udtfs.append(node)
154            elif isinstance(node, exp.CTE):
155                self._ctes.append(node)
156            elif _is_derived_table(node) and _is_from_or_join(node):
157                self._derived_tables.append(node)
158            elif isinstance(node, exp.UNWRAPPED_QUERIES):
159                self._subqueries.append(node)
160
161        self._collected = True
162
163    def _ensure_collected(self):
164        if not self._collected:
165            self._collect()
166
167    def walk(self, bfs=True, prune=None):
168        return walk_in_scope(self.expression, bfs=bfs, prune=None)
169
170    def find(self, *expression_types, bfs=True):
171        return find_in_scope(self.expression, expression_types, bfs=bfs)
172
173    def find_all(self, *expression_types, bfs=True):
174        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
175
176    def replace(self, old, new):
177        """
178        Replace `old` with `new`.
179
180        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
181
182        Args:
183            old (exp.Expression): old node
184            new (exp.Expression): new node
185        """
186        old.replace(new)
187        self.clear_cache()
188
189    @property
190    def tables(self):
191        """
192        List of tables in this scope.
193
194        Returns:
195            list[exp.Table]: tables
196        """
197        self._ensure_collected()
198        return self._tables
199
200    @property
201    def ctes(self):
202        """
203        List of CTEs in this scope.
204
205        Returns:
206            list[exp.CTE]: ctes
207        """
208        self._ensure_collected()
209        return self._ctes
210
211    @property
212    def derived_tables(self):
213        """
214        List of derived tables in this scope.
215
216        For example:
217            SELECT * FROM (SELECT ...) <- that's a derived table
218
219        Returns:
220            list[exp.Subquery]: derived tables
221        """
222        self._ensure_collected()
223        return self._derived_tables
224
225    @property
226    def udtfs(self):
227        """
228        List of "User Defined Tabular Functions" in this scope.
229
230        Returns:
231            list[exp.UDTF]: UDTFs
232        """
233        self._ensure_collected()
234        return self._udtfs
235
236    @property
237    def subqueries(self):
238        """
239        List of subqueries in this scope.
240
241        For example:
242            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
243
244        Returns:
245            list[exp.Select | exp.SetOperation]: subqueries
246        """
247        self._ensure_collected()
248        return self._subqueries
249
250    @property
251    def stars(self) -> t.List[exp.Column | exp.Dot]:
252        """
253        List of star expressions (columns or dots) in this scope.
254        """
255        self._ensure_collected()
256        return self._stars
257
258    @property
259    def columns(self):
260        """
261        List of columns in this scope.
262
263        Returns:
264            list[exp.Column]: Column instances in this scope, plus any
265                Columns that reference this scope from correlated subqueries.
266        """
267        if self._columns is None:
268            self._ensure_collected()
269            columns = self._raw_columns
270
271            external_columns = [
272                column
273                for scope in itertools.chain(
274                    self.subquery_scopes,
275                    self.udtf_scopes,
276                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
277                )
278                for column in scope.external_columns
279            ]
280
281            named_selects = set(self.expression.named_selects)
282
283            self._columns = []
284            for column in columns + external_columns:
285                ancestor = column.find_ancestor(
286                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
287                )
288                if (
289                    not ancestor
290                    or column.table
291                    or isinstance(ancestor, exp.Select)
292                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
293                    or (
294                        isinstance(ancestor, exp.Order)
295                        and (
296                            isinstance(ancestor.parent, exp.Window)
297                            or column.name not in named_selects
298                        )
299                    )
300                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
301                ):
302                    self._columns.append(column)
303
304        return self._columns
305
306    @property
307    def selected_sources(self):
308        """
309        Mapping of nodes and sources that are actually selected from in this scope.
310
311        That is, all tables in a schema are selectable at any point. But a
312        table only becomes a selected source if it's included in a FROM or JOIN clause.
313
314        Returns:
315            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
316        """
317        if self._selected_sources is None:
318            result = {}
319
320            for name, node in self.references:
321                if name in self._semi_anti_join_tables:
322                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
323                    # selected source
324                    continue
325
326                if name in result:
327                    raise OptimizeError(f"Alias already used: {name}")
328                if name in self.sources:
329                    result[name] = (node, self.sources[name])
330
331            self._selected_sources = result
332        return self._selected_sources
333
334    @property
335    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
336        if self._references is None:
337            self._references = []
338
339            for table in self.tables:
340                self._references.append((table.alias_or_name, table))
341            for expression in itertools.chain(self.derived_tables, self.udtfs):
342                self._references.append(
343                    (
344                        expression.alias,
345                        expression if expression.args.get("pivots") else expression.unnest(),
346                    )
347                )
348
349        return self._references
350
351    @property
352    def external_columns(self):
353        """
354        Columns that appear to reference sources in outer scopes.
355
356        Returns:
357            list[exp.Column]: Column instances that don't reference
358                sources in the current scope.
359        """
360        if self._external_columns is None:
361            if isinstance(self.expression, exp.SetOperation):
362                left, right = self.union_scopes
363                self._external_columns = left.external_columns + right.external_columns
364            else:
365                self._external_columns = [
366                    c
367                    for c in self.columns
368                    if c.table not in self.selected_sources
369                    and c.table not in self.semi_or_anti_join_tables
370                ]
371
372        return self._external_columns
373
374    @property
375    def unqualified_columns(self):
376        """
377        Unqualified columns in the current scope.
378
379        Returns:
380             list[exp.Column]: Unqualified columns
381        """
382        return [c for c in self.columns if not c.table]
383
384    @property
385    def join_hints(self):
386        """
387        Hints that exist in the scope that reference tables
388
389        Returns:
390            list[exp.JoinHint]: Join hints that are referenced within the scope
391        """
392        if self._join_hints is None:
393            return []
394        return self._join_hints
395
396    @property
397    def pivots(self):
398        if not self._pivots:
399            self._pivots = [
400                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
401            ]
402
403        return self._pivots
404
405    @property
406    def semi_or_anti_join_tables(self):
407        return self._semi_anti_join_tables or set()
408
409    def source_columns(self, source_name):
410        """
411        Get all columns in the current scope for a particular source.
412
413        Args:
414            source_name (str): Name of the source
415        Returns:
416            list[exp.Column]: Column instances that reference `source_name`
417        """
418        return [column for column in self.columns if column.table == source_name]
419
420    @property
421    def is_subquery(self):
422        """Determine if this scope is a subquery"""
423        return self.scope_type == ScopeType.SUBQUERY
424
425    @property
426    def is_derived_table(self):
427        """Determine if this scope is a derived table"""
428        return self.scope_type == ScopeType.DERIVED_TABLE
429
430    @property
431    def is_union(self):
432        """Determine if this scope is a union"""
433        return self.scope_type == ScopeType.UNION
434
435    @property
436    def is_cte(self):
437        """Determine if this scope is a common table expression"""
438        return self.scope_type == ScopeType.CTE
439
440    @property
441    def is_root(self):
442        """Determine if this is the root scope"""
443        return self.scope_type == ScopeType.ROOT
444
445    @property
446    def is_udtf(self):
447        """Determine if this scope is a UDTF (User Defined Table Function)"""
448        return self.scope_type == ScopeType.UDTF
449
450    @property
451    def is_correlated_subquery(self):
452        """Determine if this scope is a correlated subquery"""
453        return bool(self.can_be_correlated and self.external_columns)
454
455    def rename_source(self, old_name, new_name):
456        """Rename a source in this scope"""
457        columns = self.sources.pop(old_name or "", [])
458        self.sources[new_name] = columns
459
460    def add_source(self, name, source):
461        """Add a source to this scope"""
462        self.sources[name] = source
463        self.clear_cache()
464
465    def remove_source(self, name):
466        """Remove a source from this scope"""
467        self.sources.pop(name, None)
468        self.clear_cache()
469
470    def __repr__(self):
471        return f"Scope<{self.expression.sql()}>"
472
473    def traverse(self):
474        """
475        Traverse the scope tree from this node.
476
477        Yields:
478            Scope: scope instances in depth-first-search post-order
479        """
480        stack = [self]
481        result = []
482        while stack:
483            scope = stack.pop()
484            result.append(scope)
485            stack.extend(
486                itertools.chain(
487                    scope.cte_scopes,
488                    scope.union_scopes,
489                    scope.table_scopes,
490                    scope.subquery_scopes,
491                )
492            )
493
494        yield from reversed(result)
495
496    def ref_count(self):
497        """
498        Count the number of times each scope in this tree is referenced.
499
500        Returns:
501            dict[int, int]: Mapping of Scope instance ID to reference count
502        """
503        scope_ref_count = defaultdict(lambda: 0)
504
505        for scope in self.traverse():
506            for _, source in scope.selected_sources.values():
507                scope_ref_count[id(source)] += 1
508
509        return scope_ref_count

Selection scope.

Attributes:
  • expression (exp.Select|exp.SetOperation): Root expression of this scope
  • sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: `SELECT * FROM x` → `{"x": Table(this="x")}`; `SELECT * FROM x AS y` → `{"y": Table(this="x")}`; `SELECT * FROM (SELECT ...) AS y` → `{"y": Scope(...)}`.
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c`, the LATERAL VIEW EXPLODE gets `x` as a source.
  • cte_sources (dict[str, Scope]): Sources from CTEs
  • outer_columns (list[str]): If this is a derived table or CTE, and the outer query defines a column list for the alias of this scope, this is that list of columns. For example, for `SELECT * FROM (SELECT ...) AS y(col1, col2)`, the inner query would have `["col1", "col2"]` as its `outer_columns`.
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to its parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
  • udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a set operation (e.g. a UNION), this will be a list of the left and right child scopes.
Scope( expression, sources=None, outer_columns=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None, cte_sources=None, can_be_correlated=None)
60    def __init__(
61        self,
62        expression,
63        sources=None,
64        outer_columns=None,
65        parent=None,
66        scope_type=ScopeType.ROOT,
67        lateral_sources=None,
68        cte_sources=None,
69        can_be_correlated=None,
70    ):
71        self.expression = expression
72        self.sources = sources or {}
73        self.lateral_sources = lateral_sources or {}
74        self.cte_sources = cte_sources or {}
75        self.sources.update(self.lateral_sources)
76        self.sources.update(self.cte_sources)
77        self.outer_columns = outer_columns or []
78        self.parent = parent
79        self.scope_type = scope_type
80        self.subquery_scopes = []
81        self.derived_table_scopes = []
82        self.table_scopes = []
83        self.cte_scopes = []
84        self.union_scopes = []
85        self.udtf_scopes = []
86        self.can_be_correlated = can_be_correlated
87        self.clear_cache()
expression
sources
lateral_sources
cte_sources
outer_columns
parent
scope_type
subquery_scopes
derived_table_scopes
table_scopes
cte_scopes
union_scopes
udtf_scopes
can_be_correlated
def clear_cache(self):
 89    def clear_cache(self):
 90        self._collected = False
 91        self._raw_columns = None
 92        self._stars = None
 93        self._derived_tables = None
 94        self._udtfs = None
 95        self._tables = None
 96        self._ctes = None
 97        self._subqueries = None
 98        self._selected_sources = None
 99        self._columns = None
100        self._external_columns = None
101        self._join_hints = None
102        self._pivots = None
103        self._references = None
104        self._semi_anti_join_tables = None
def branch( self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs):
106    def branch(
107        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
108    ):
109        """Branch from the current scope to a new, inner scope"""
110        return Scope(
111            expression=expression.unnest(),
112            sources=sources.copy() if sources else None,
113            parent=self,
114            scope_type=scope_type,
115            cte_sources={**self.cte_sources, **(cte_sources or {})},
116            lateral_sources=lateral_sources.copy() if lateral_sources else None,
117            can_be_correlated=self.can_be_correlated
118            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
119            **kwargs,
120        )

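Branch from the current scope to a new, inner scope. The new scope inherits this scope's CTE sources, and it is marked as possibly correlated if this scope is, or if the new scope is a subquery or UDTF scope.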

def walk(self, bfs=True, prune=None):
167    def walk(self, bfs=True, prune=None):
168        return walk_in_scope(self.expression, bfs=bfs, prune=None)
def find(self, *expression_types, bfs=True):
170    def find(self, *expression_types, bfs=True):
171        return find_in_scope(self.expression, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
173    def find_all(self, *expression_types, bfs=True):
174        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
def replace(self, old, new):
176    def replace(self, old, new):
177        """
178        Replace `old` with `new`.
179
180        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
181
182        Args:
183            old (exp.Expression): old node
184            new (exp.Expression): new node
185        """
186        old.replace(new)
187        self.clear_cache()

Replace old with new.

This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expression): old node
  • new (exp.Expression): new node
tables
189    @property
190    def tables(self):
191        """
192        List of tables in this scope.
193
194        Returns:
195            list[exp.Table]: tables
196        """
197        self._ensure_collected()
198        return self._tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes
200    @property
201    def ctes(self):
202        """
203        List of CTEs in this scope.
204
205        Returns:
206            list[exp.CTE]: ctes
207        """
208        self._ensure_collected()
209        return self._ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables
211    @property
212    def derived_tables(self):
213        """
214        List of derived tables in this scope.
215
216        For example:
217            SELECT * FROM (SELECT ...) <- that's a derived table
218
219        Returns:
220            list[exp.Subquery]: derived tables
221        """
222        self._ensure_collected()
223        return self._derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs
225    @property
226    def udtfs(self):
227        """
228        List of "User Defined Tabular Functions" in this scope.
229
230        Returns:
231            list[exp.UDTF]: UDTFs
232        """
233        self._ensure_collected()
234        return self._udtfs

List of "User Defined Tabular Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

subqueries
236    @property
237    def subqueries(self):
238        """
239        List of subqueries in this scope.
240
241        For example:
242            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
243
244        Returns:
245            list[exp.Select | exp.SetOperation]: subqueries
246        """
247        self._ensure_collected()
248        return self._subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Select | exp.SetOperation]: subqueries

250    @property
251    def stars(self) -> t.List[exp.Column | exp.Dot]:
252        """
253        List of star expressions (columns or dots) in this scope.
254        """
255        self._ensure_collected()
256        return self._stars

List of star expressions (columns or dots) in this scope.

columns
258    @property
259    def columns(self):
260        """
261        List of columns in this scope.
262
263        Returns:
264            list[exp.Column]: Column instances in this scope, plus any
265                Columns that reference this scope from correlated subqueries.
266        """
267        if self._columns is None:
268            self._ensure_collected()
269            columns = self._raw_columns
270
271            external_columns = [
272                column
273                for scope in itertools.chain(
274                    self.subquery_scopes,
275                    self.udtf_scopes,
276                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
277                )
278                for column in scope.external_columns
279            ]
280
281            named_selects = set(self.expression.named_selects)
282
283            self._columns = []
284            for column in columns + external_columns:
285                ancestor = column.find_ancestor(
286                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
287                )
288                if (
289                    not ancestor
290                    or column.table
291                    or isinstance(ancestor, exp.Select)
292                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
293                    or (
294                        isinstance(ancestor, exp.Order)
295                        and (
296                            isinstance(ancestor.parent, exp.Window)
297                            or column.name not in named_selects
298                        )
299                    )
300                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
301                ):
302                    self._columns.append(column)
303
304        return self._columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

selected_sources
306    @property
307    def selected_sources(self):
308        """
309        Mapping of nodes and sources that are actually selected from in this scope.
310
311        That is, all tables in a schema are selectable at any point. But a
312        table only becomes a selected source if it's included in a FROM or JOIN clause.
313
314        Returns:
315            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
316        """
317        if self._selected_sources is None:
318            result = {}
319
320            for name, node in self.references:
321                if name in self._semi_anti_join_tables:
322                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
323                    # selected source
324                    continue
325
326                if name in result:
327                    raise OptimizeError(f"Alias already used: {name}")
328                if name in self.sources:
329                    result[name] = (node, self.sources[name])
330
331            self._selected_sources = result
332        return self._selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

references: List[Tuple[str, sqlglot.expressions.Expression]]
334    @property
335    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
336        if self._references is None:
337            self._references = []
338
339            for table in self.tables:
340                self._references.append((table.alias_or_name, table))
341            for expression in itertools.chain(self.derived_tables, self.udtfs):
342                self._references.append(
343                    (
344                        expression.alias,
345                        expression if expression.args.get("pivots") else expression.unnest(),
346                    )
347                )
348
349        return self._references
external_columns
351    @property
352    def external_columns(self):
353        """
354        Columns that appear to reference sources in outer scopes.
355
356        Returns:
357            list[exp.Column]: Column instances that don't reference
358                sources in the current scope.
359        """
360        if self._external_columns is None:
361            if isinstance(self.expression, exp.SetOperation):
362                left, right = self.union_scopes
363                self._external_columns = left.external_columns + right.external_columns
364            else:
365                self._external_columns = [
366                    c
367                    for c in self.columns
368                    if c.table not in self.selected_sources
369                    and c.table not in self.semi_or_anti_join_tables
370                ]
371
372        return self._external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

unqualified_columns
374    @property
375    def unqualified_columns(self):
376        """
377        Unqualified columns in the current scope.
378
379        Returns:
380             list[exp.Column]: Unqualified columns
381        """
382        return [c for c in self.columns if not c.table]

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints
384    @property
385    def join_hints(self):
386        """
387        Hints that exist in the scope that reference tables
388
389        Returns:
390            list[exp.JoinHint]: Join hints that are referenced within the scope
391        """
392        if self._join_hints is None:
393            return []
394        return self._join_hints

Hints that exist in the scope that reference tables.

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

pivots
396    @property
397    def pivots(self):
398        if not self._pivots:
399            self._pivots = [
400                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
401            ]
402
403        return self._pivots
semi_or_anti_join_tables
405    @property
406    def semi_or_anti_join_tables(self):
407        return self._semi_anti_join_tables or set()
def source_columns(self, source_name):
409    def source_columns(self, source_name):
410        """
411        Get all columns in the current scope for a particular source.
412
413        Args:
414            source_name (str): Name of the source
415        Returns:
416            list[exp.Column]: Column instances that reference `source_name`
417        """
418        return [column for column in self.columns if column.table == source_name]

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery
420    @property
421    def is_subquery(self):
422        """Determine if this scope is a subquery"""
423        return self.scope_type == ScopeType.SUBQUERY

Determine if this scope is a subquery

is_derived_table
425    @property
426    def is_derived_table(self):
427        """Determine if this scope is a derived table"""
428        return self.scope_type == ScopeType.DERIVED_TABLE

Determine if this scope is a derived table

is_union
430    @property
431    def is_union(self):
432        """Determine if this scope is a union"""
433        return self.scope_type == ScopeType.UNION

Determine if this scope is a union

is_cte
435    @property
436    def is_cte(self):
437        """Determine if this scope is a common table expression"""
438        return self.scope_type == ScopeType.CTE

Determine if this scope is a common table expression

is_root
440    @property
441    def is_root(self):
442        """Determine if this is the root scope"""
443        return self.scope_type == ScopeType.ROOT

Determine if this is the root scope

is_udtf
445    @property
446    def is_udtf(self):
447        """Determine if this scope is a UDTF (User Defined Table Function)"""
448        return self.scope_type == ScopeType.UDTF

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery
450    @property
451    def is_correlated_subquery(self):
452        """Determine if this scope is a correlated subquery"""
453        return bool(self.can_be_correlated and self.external_columns)

Determine if this scope is a correlated subquery

def rename_source(self, old_name, new_name):
455    def rename_source(self, old_name, new_name):
456        """Rename a source in this scope"""
457        columns = self.sources.pop(old_name or "", [])
458        self.sources[new_name] = columns

Rename a source in this scope

def add_source(self, name, source):
460    def add_source(self, name, source):
461        """Add a source to this scope"""
462        self.sources[name] = source
463        self.clear_cache()

Add a source to this scope

def remove_source(self, name):
465    def remove_source(self, name):
466        """Remove a source from this scope"""
467        self.sources.pop(name, None)
468        self.clear_cache()

Remove a source from this scope

def traverse(self):
473    def traverse(self):
474        """
475        Traverse the scope tree from this node.
476
477        Yields:
478            Scope: scope instances in depth-first-search post-order
479        """
480        stack = [self]
481        result = []
482        while stack:
483            scope = stack.pop()
484            result.append(scope)
485            stack.extend(
486                itertools.chain(
487                    scope.cte_scopes,
488                    scope.union_scopes,
489                    scope.table_scopes,
490                    scope.subquery_scopes,
491                )
492            )
493
494        yield from reversed(result)

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self):
496    def ref_count(self):
497        """
498        Count the number of times each scope in this tree is referenced.
499
500        Returns:
501            dict[int, int]: Mapping of Scope instance ID to reference count
502        """
503        scope_ref_count = defaultdict(lambda: 0)
504
505        for scope in self.traverse():
506            for _, source in scope.selected_sources.values():
507                scope_ref_count[id(source)] += 1
508
509        return scope_ref_count

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope( expression: sqlglot.expressions.Expression) -> List[Scope]:
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse. Anything that is not one of
            TRAVERSABLES produces an empty list.

    Returns:
        A list of the created scope instances
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    # Materialize eagerly so callers never observe half-built scope properties.
    root_scope = Scope(expression)
    return list(_traverse_scope(root_scope))

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression: Expression to traverse
Returns:

A list of the created scope instances

def build_scope( expression: sqlglot.expressions.Expression) -> Optional[Scope]:
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope (the last scope produced by `traverse_scope`), or None when
        no scopes were produced.
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None

Build a scope tree.

Arguments:
  • expression: Expression to build the scope tree for.
Returns:

The root scope

def walk_in_scope(expression, bfs=True, prune=None):
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Nodes that start a child scope (CTEs, derived tables / UDTFs under a From, Join
    or Subquery, and UNWRAPPED_QUERIES) are themselves yielded, but their subtrees are
    not. The exception is the "joins", "laterals" and "pivots" args of a Subquery or
    UDTF boundary node, which are walked because they belong to the current scope.

    Args:
        expression (exp.Expression): the node to start walking from.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that receives a node and returns True if
            the generator should stop traversing this branch of the tree. It is
            not forwarded to the nested walks over the "joins", "laterals" and
            "pivots" args mentioned above.

    Yields:
        exp.Expression: the visited nodes
    """
    # We'll use this variable to pass state into the walk generator.
    # Whenever we set it to True right after yielding a node, that node's subtree
    # is excluded from traversal. This relies on `walk` evaluating `prune` for a node
    # only after resuming from that node's yield.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        # Reset for every node so only scope-boundary nodes get their subtree pruned.
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs)

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.

Arguments:
  • expression (exp.Expression):
  • bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
  • prune ((node) -> bool): callable that receives a node and returns True if the generator should stop traversing this branch of the tree.
Yields:

exp.Expression: the visited nodes

def find_all_in_scope(expression, expression_types, bfs=True):
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the node whose scope is searched.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the requested types once instead of rebuilding the tuple for every
    # visited node; the value does not depend on the loop.
    wanted_types = tuple(ensure_collection(expression_types))

    # Use a distinct loop name so the `expression` parameter is not shadowed.
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted_types):
            yield node

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:

exp.Expression: nodes

def find_in_scope(expression, expression_types, bfs=True):
def find_in_scope(expression, expression_types, bfs=True):
    """
    Return the first node in this scope that is an instance of any of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the node whose scope is searched.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None when nothing matches.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:

exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.