Edit on GitHub

sqlglot.optimizer.scope

  1from __future__ import annotations
  2
  3import itertools
  4import logging
  5import typing as t
  6from collections import defaultdict
  7from enum import Enum, auto
  8
  9from sqlglot import exp
 10from sqlglot.errors import OptimizeError
 11from sqlglot.helper import ensure_collection, find_new_name, seq_get
 12
# Shared "sqlglot" logger; used below to warn about expressions that cannot be traversed.
logger = logging.getLogger("sqlglot")

# Root expression types that `traverse_scope` builds scopes for; any other type yields [].
TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)
 16
 17
 18class ScopeType(Enum):
 19    ROOT = auto()
 20    SUBQUERY = auto()
 21    DERIVED_TABLE = auto()
 22    CTE = auto()
 23    UNION = auto()
 24    UDTF = auto()
 25
 26
 27class Scope:
 28    """
 29    Selection scope.
 30
 31    Attributes:
 32        expression (exp.Select|exp.SetOperation): Root expression of this scope
 33        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 34            a Table expression or another Scope instance. For example:
 35                SELECT * FROM x                     {"x": Table(this="x")}
 36                SELECT * FROM x AS y                {"y": Table(this="x")}
 37                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 38        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 39            For example:
 40                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 41            The LATERAL VIEW EXPLODE gets x as a source.
 42        cte_sources (dict[str, Scope]): Sources from CTES
 43        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
 44            defines a column list for the alias of this scope, this is that list of columns.
 45            For example:
 46                SELECT * FROM (SELECT ...) AS y(col1, col2)
 47            The inner query would have `["col1", "col2"]` for its `outer_columns`
 48        parent (Scope): Parent scope
 49        scope_type (ScopeType): Type of this scope, relative to it's parent
 50        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 51        cte_scopes (list[Scope]): List of all child scopes for CTEs
 52        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 53        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 54        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 55        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 56            a list of the left and right child scopes.
 57    """
 58
 59    def __init__(
 60        self,
 61        expression,
 62        sources=None,
 63        outer_columns=None,
 64        parent=None,
 65        scope_type=ScopeType.ROOT,
 66        lateral_sources=None,
 67        cte_sources=None,
 68        can_be_correlated=None,
 69    ):
 70        self.expression = expression
 71        self.sources = sources or {}
 72        self.lateral_sources = lateral_sources or {}
 73        self.cte_sources = cte_sources or {}
 74        self.sources.update(self.lateral_sources)
 75        self.sources.update(self.cte_sources)
 76        self.outer_columns = outer_columns or []
 77        self.parent = parent
 78        self.scope_type = scope_type
 79        self.subquery_scopes = []
 80        self.derived_table_scopes = []
 81        self.table_scopes = []
 82        self.cte_scopes = []
 83        self.union_scopes = []
 84        self.udtf_scopes = []
 85        self.can_be_correlated = can_be_correlated
 86        self.clear_cache()
 87
 88    def clear_cache(self):
 89        self._collected = False
 90        self._raw_columns = None
 91        self._table_columns = None
 92        self._stars = None
 93        self._derived_tables = None
 94        self._udtfs = None
 95        self._tables = None
 96        self._ctes = None
 97        self._subqueries = None
 98        self._selected_sources = None
 99        self._columns = None
100        self._external_columns = None
101        self._join_hints = None
102        self._pivots = None
103        self._references = None
104        self._semi_anti_join_tables = None
105
106    def branch(
107        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
108    ):
109        """Branch from the current scope to a new, inner scope"""
110        return Scope(
111            expression=expression.unnest(),
112            sources=sources.copy() if sources else None,
113            parent=self,
114            scope_type=scope_type,
115            cte_sources={**self.cte_sources, **(cte_sources or {})},
116            lateral_sources=lateral_sources.copy() if lateral_sources else None,
117            can_be_correlated=self.can_be_correlated
118            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
119            **kwargs,
120        )
121
122    def _collect(self):
123        self._tables = []
124        self._ctes = []
125        self._subqueries = []
126        self._derived_tables = []
127        self._udtfs = []
128        self._raw_columns = []
129        self._table_columns = []
130        self._stars = []
131        self._join_hints = []
132        self._semi_anti_join_tables = set()
133
134        for node in self.walk(bfs=False):
135            if node is self.expression:
136                continue
137
138            if isinstance(node, exp.Dot) and node.is_star:
139                self._stars.append(node)
140            elif isinstance(node, exp.Column):
141                if isinstance(node.this, exp.Star):
142                    self._stars.append(node)
143                else:
144                    self._raw_columns.append(node)
145            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
146                parent = node.parent
147                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
148                    self._semi_anti_join_tables.add(node.alias_or_name)
149
150                self._tables.append(node)
151            elif isinstance(node, exp.JoinHint):
152                self._join_hints.append(node)
153            elif isinstance(node, exp.UDTF):
154                self._udtfs.append(node)
155            elif isinstance(node, exp.CTE):
156                self._ctes.append(node)
157            elif _is_derived_table(node) and _is_from_or_join(node):
158                self._derived_tables.append(node)
159            elif isinstance(node, exp.UNWRAPPED_QUERIES):
160                self._subqueries.append(node)
161            elif isinstance(node, exp.TableColumn):
162                self._table_columns.append(node)
163
164        self._collected = True
165
166    def _ensure_collected(self):
167        if not self._collected:
168            self._collect()
169
170    def walk(self, bfs=True, prune=None):
171        return walk_in_scope(self.expression, bfs=bfs, prune=None)
172
173    def find(self, *expression_types, bfs=True):
174        return find_in_scope(self.expression, expression_types, bfs=bfs)
175
176    def find_all(self, *expression_types, bfs=True):
177        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
178
179    def replace(self, old, new):
180        """
181        Replace `old` with `new`.
182
183        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
184
185        Args:
186            old (exp.Expression): old node
187            new (exp.Expression): new node
188        """
189        old.replace(new)
190        self.clear_cache()
191
192    @property
193    def tables(self):
194        """
195        List of tables in this scope.
196
197        Returns:
198            list[exp.Table]: tables
199        """
200        self._ensure_collected()
201        return self._tables
202
203    @property
204    def ctes(self):
205        """
206        List of CTEs in this scope.
207
208        Returns:
209            list[exp.CTE]: ctes
210        """
211        self._ensure_collected()
212        return self._ctes
213
214    @property
215    def derived_tables(self):
216        """
217        List of derived tables in this scope.
218
219        For example:
220            SELECT * FROM (SELECT ...) <- that's a derived table
221
222        Returns:
223            list[exp.Subquery]: derived tables
224        """
225        self._ensure_collected()
226        return self._derived_tables
227
228    @property
229    def udtfs(self):
230        """
231        List of "User Defined Tabular Functions" in this scope.
232
233        Returns:
234            list[exp.UDTF]: UDTFs
235        """
236        self._ensure_collected()
237        return self._udtfs
238
239    @property
240    def subqueries(self):
241        """
242        List of subqueries in this scope.
243
244        For example:
245            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
246
247        Returns:
248            list[exp.Select | exp.SetOperation]: subqueries
249        """
250        self._ensure_collected()
251        return self._subqueries
252
253    @property
254    def stars(self) -> t.List[exp.Column | exp.Dot]:
255        """
256        List of star expressions (columns or dots) in this scope.
257        """
258        self._ensure_collected()
259        return self._stars
260
261    @property
262    def columns(self):
263        """
264        List of columns in this scope.
265
266        Returns:
267            list[exp.Column]: Column instances in this scope, plus any
268                Columns that reference this scope from correlated subqueries.
269        """
270        if self._columns is None:
271            self._ensure_collected()
272            columns = self._raw_columns
273
274            external_columns = [
275                column
276                for scope in itertools.chain(
277                    self.subquery_scopes,
278                    self.udtf_scopes,
279                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
280                )
281                for column in scope.external_columns
282            ]
283
284            named_selects = set(self.expression.named_selects)
285
286            self._columns = []
287            for column in columns + external_columns:
288                ancestor = column.find_ancestor(
289                    exp.Select,
290                    exp.Qualify,
291                    exp.Order,
292                    exp.Having,
293                    exp.Hint,
294                    exp.Table,
295                    exp.Star,
296                    exp.Distinct,
297                )
298                if (
299                    not ancestor
300                    or column.table
301                    or isinstance(ancestor, exp.Select)
302                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
303                    or (
304                        isinstance(ancestor, (exp.Order, exp.Distinct))
305                        and (
306                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
307                            or column.name not in named_selects
308                        )
309                    )
310                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
311                ):
312                    self._columns.append(column)
313
314        return self._columns
315
316    @property
317    def table_columns(self):
318        if self._table_columns is None:
319            self._ensure_collected()
320
321        return self._table_columns
322
323    @property
324    def selected_sources(self):
325        """
326        Mapping of nodes and sources that are actually selected from in this scope.
327
328        That is, all tables in a schema are selectable at any point. But a
329        table only becomes a selected source if it's included in a FROM or JOIN clause.
330
331        Returns:
332            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
333        """
334        if self._selected_sources is None:
335            result = {}
336
337            for name, node in self.references:
338                if name in self._semi_anti_join_tables:
339                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
340                    # selected source
341                    continue
342
343                if name in result:
344                    raise OptimizeError(f"Alias already used: {name}")
345                if name in self.sources:
346                    result[name] = (node, self.sources[name])
347
348            self._selected_sources = result
349        return self._selected_sources
350
351    @property
352    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
353        if self._references is None:
354            self._references = []
355
356            for table in self.tables:
357                self._references.append((table.alias_or_name, table))
358            for expression in itertools.chain(self.derived_tables, self.udtfs):
359                self._references.append(
360                    (
361                        _get_source_alias(expression),
362                        expression if expression.args.get("pivots") else expression.unnest(),
363                    )
364                )
365
366        return self._references
367
368    @property
369    def external_columns(self):
370        """
371        Columns that appear to reference sources in outer scopes.
372
373        Returns:
374            list[exp.Column]: Column instances that don't reference
375                sources in the current scope.
376        """
377        if self._external_columns is None:
378            if isinstance(self.expression, exp.SetOperation):
379                left, right = self.union_scopes
380                self._external_columns = left.external_columns + right.external_columns
381            else:
382                self._external_columns = [
383                    c
384                    for c in self.columns
385                    if c.table not in self.selected_sources
386                    and c.table not in self.semi_or_anti_join_tables
387                ]
388
389        return self._external_columns
390
391    @property
392    def unqualified_columns(self):
393        """
394        Unqualified columns in the current scope.
395
396        Returns:
397             list[exp.Column]: Unqualified columns
398        """
399        return [c for c in self.columns if not c.table]
400
401    @property
402    def join_hints(self):
403        """
404        Hints that exist in the scope that reference tables
405
406        Returns:
407            list[exp.JoinHint]: Join hints that are referenced within the scope
408        """
409        if self._join_hints is None:
410            return []
411        return self._join_hints
412
413    @property
414    def pivots(self):
415        if not self._pivots:
416            self._pivots = [
417                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
418            ]
419
420        return self._pivots
421
422    @property
423    def semi_or_anti_join_tables(self):
424        return self._semi_anti_join_tables or set()
425
426    def source_columns(self, source_name):
427        """
428        Get all columns in the current scope for a particular source.
429
430        Args:
431            source_name (str): Name of the source
432        Returns:
433            list[exp.Column]: Column instances that reference `source_name`
434        """
435        return [column for column in self.columns if column.table == source_name]
436
437    @property
438    def is_subquery(self):
439        """Determine if this scope is a subquery"""
440        return self.scope_type == ScopeType.SUBQUERY
441
442    @property
443    def is_derived_table(self):
444        """Determine if this scope is a derived table"""
445        return self.scope_type == ScopeType.DERIVED_TABLE
446
447    @property
448    def is_union(self):
449        """Determine if this scope is a union"""
450        return self.scope_type == ScopeType.UNION
451
452    @property
453    def is_cte(self):
454        """Determine if this scope is a common table expression"""
455        return self.scope_type == ScopeType.CTE
456
457    @property
458    def is_root(self):
459        """Determine if this is the root scope"""
460        return self.scope_type == ScopeType.ROOT
461
462    @property
463    def is_udtf(self):
464        """Determine if this scope is a UDTF (User Defined Table Function)"""
465        return self.scope_type == ScopeType.UDTF
466
467    @property
468    def is_correlated_subquery(self):
469        """Determine if this scope is a correlated subquery"""
470        return bool(self.can_be_correlated and self.external_columns)
471
472    def rename_source(self, old_name, new_name):
473        """Rename a source in this scope"""
474        columns = self.sources.pop(old_name or "", [])
475        self.sources[new_name] = columns
476
477    def add_source(self, name, source):
478        """Add a source to this scope"""
479        self.sources[name] = source
480        self.clear_cache()
481
482    def remove_source(self, name):
483        """Remove a source from this scope"""
484        self.sources.pop(name, None)
485        self.clear_cache()
486
487    def __repr__(self):
488        return f"Scope<{self.expression.sql()}>"
489
490    def traverse(self):
491        """
492        Traverse the scope tree from this node.
493
494        Yields:
495            Scope: scope instances in depth-first-search post-order
496        """
497        stack = [self]
498        result = []
499        while stack:
500            scope = stack.pop()
501            result.append(scope)
502            stack.extend(
503                itertools.chain(
504                    scope.cte_scopes,
505                    scope.union_scopes,
506                    scope.table_scopes,
507                    scope.subquery_scopes,
508                )
509            )
510
511        yield from reversed(result)
512
513    def ref_count(self):
514        """
515        Count the number of times each scope in this tree is referenced.
516
517        Returns:
518            dict[int, int]: Mapping of Scope instance ID to reference count
519        """
520        scope_ref_count = defaultdict(lambda: 0)
521
522        for scope in self.traverse():
523            for _, source in scope.selected_sources.values():
524                scope_ref_count[id(source)] += 1
525
526            for name in scope._semi_anti_join_tables:
527                # semi/anti join sources are not actually selected but we still need to
528                # increment their ref count to avoid them being optimized away
529                if name in scope.sources:
530                    scope_ref_count[id(scope.sources[name])] += 1
531
532        return scope_ref_count
533
534
535def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
536    """
537    Traverse an expression by its "scopes".
538
539    "Scope" represents the current context of a Select statement.
540
541    This is helpful for optimizing queries, where we need more information than
542    the expression tree itself. For example, we might care about the source
543    names within a subquery. Returns a list because a generator could result in
544    incomplete properties which is confusing.
545
546    Examples:
547        >>> import sqlglot
548        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
549        >>> scopes = traverse_scope(expression)
550        >>> scopes[0].expression.sql(), list(scopes[0].sources)
551        ('SELECT a FROM x', ['x'])
552        >>> scopes[1].expression.sql(), list(scopes[1].sources)
553        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
554
555    Args:
556        expression: Expression to traverse
557
558    Returns:
559        A list of the created scope instances
560    """
561    if isinstance(expression, TRAVERSABLES):
562        return list(_traverse_scope(Scope(expression)))
563    return []
564
565
566def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
567    """
568    Build a scope tree.
569
570    Args:
571        expression: Expression to build the scope tree for.
572
573    Returns:
574        The root scope
575    """
576    return seq_get(traverse_scope(expression), -1)
577
578
579def _traverse_scope(scope):
580    expression = scope.expression
581
582    if isinstance(expression, exp.Select):
583        yield from _traverse_select(scope)
584    elif isinstance(expression, exp.SetOperation):
585        yield from _traverse_ctes(scope)
586        yield from _traverse_union(scope)
587        return
588    elif isinstance(expression, exp.Subquery):
589        if scope.is_root:
590            yield from _traverse_select(scope)
591        else:
592            yield from _traverse_subqueries(scope)
593    elif isinstance(expression, exp.Table):
594        yield from _traverse_tables(scope)
595    elif isinstance(expression, exp.UDTF):
596        yield from _traverse_udtfs(scope)
597    elif isinstance(expression, exp.DDL):
598        if isinstance(expression.expression, exp.Query):
599            yield from _traverse_ctes(scope)
600            yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources))
601        return
602    elif isinstance(expression, exp.DML):
603        yield from _traverse_ctes(scope)
604        for query in find_all_in_scope(expression, exp.Query):
605            # This check ensures we don't yield the CTE/nested queries twice
606            if not isinstance(query.parent, (exp.CTE, exp.Subquery)):
607                yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
608        return
609    else:
610        logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
611        return
612
613    yield scope
614
615
616def _traverse_select(scope):
617    yield from _traverse_ctes(scope)
618    yield from _traverse_tables(scope)
619    yield from _traverse_subqueries(scope)
620
621
622def _traverse_union(scope):
623    prev_scope = None
624    union_scope_stack = [scope]
625    expression_stack = [scope.expression.right, scope.expression.left]
626
627    while expression_stack:
628        expression = expression_stack.pop()
629        union_scope = union_scope_stack[-1]
630
631        new_scope = union_scope.branch(
632            expression,
633            outer_columns=union_scope.outer_columns,
634            scope_type=ScopeType.UNION,
635        )
636
637        if isinstance(expression, exp.SetOperation):
638            yield from _traverse_ctes(new_scope)
639
640            union_scope_stack.append(new_scope)
641            expression_stack.extend([expression.right, expression.left])
642            continue
643
644        for scope in _traverse_scope(new_scope):
645            yield scope
646
647        if prev_scope:
648            union_scope_stack.pop()
649            union_scope.union_scopes = [prev_scope, scope]
650            prev_scope = union_scope
651
652            yield union_scope
653        else:
654            prev_scope = scope
655
656
657def _traverse_ctes(scope):
658    sources = {}
659
660    for cte in scope.ctes:
661        cte_name = cte.alias
662
663        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
664        # thus the recursive scope is the first section of the union.
665        with_ = scope.expression.args.get("with")
666        if with_ and with_.recursive:
667            union = cte.this
668
669            if isinstance(union, exp.SetOperation):
670                sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)
671
672        child_scope = None
673
674        for child_scope in _traverse_scope(
675            scope.branch(
676                cte.this,
677                cte_sources=sources,
678                outer_columns=cte.alias_column_names,
679                scope_type=ScopeType.CTE,
680            )
681        ):
682            yield child_scope
683
684        # append the final child_scope yielded
685        if child_scope:
686            sources[cte_name] = child_scope
687            scope.cte_scopes.append(child_scope)
688
689    scope.sources.update(sources)
690    scope.cte_sources.update(sources)
691
692
693def _is_derived_table(expression: exp.Subquery) -> bool:
694    """
695    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
696    as it doesn't introduce a new scope. If an alias is present, it shadows all names
697    under the Subquery, so that's one exception to this rule.
698    """
699    return isinstance(expression, exp.Subquery) and bool(
700        expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
701    )
702
703
704def _is_from_or_join(expression: exp.Expression) -> bool:
705    """
706    Determine if `expression` is the FROM or JOIN clause of a SELECT statement.
707    """
708    parent = expression.parent
709
710    # Subqueries can be arbitrarily nested
711    while isinstance(parent, exp.Subquery):
712        parent = parent.parent
713
714    return isinstance(parent, (exp.From, exp.Join))
715
716
717def _traverse_tables(scope):
718    sources = {}
719
720    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
721    expressions = []
722    from_ = scope.expression.args.get("from")
723    if from_:
724        expressions.append(from_.this)
725
726    for join in scope.expression.args.get("joins") or []:
727        expressions.append(join.this)
728
729    if isinstance(scope.expression, exp.Table):
730        expressions.append(scope.expression)
731
732    expressions.extend(scope.expression.args.get("laterals") or [])
733
734    for expression in expressions:
735        if isinstance(expression, exp.Final):
736            expression = expression.this
737        if isinstance(expression, exp.Table):
738            table_name = expression.name
739            source_name = expression.alias_or_name
740
741            if table_name in scope.sources and not expression.db:
742                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
743                # it is pivoted, because then we get back a new table and hence a new source.
744                pivots = expression.args.get("pivots")
745                if pivots:
746                    sources[pivots[0].alias] = expression
747                else:
748                    sources[source_name] = scope.sources[table_name]
749            elif source_name in sources:
750                sources[find_new_name(sources, table_name)] = expression
751            else:
752                sources[source_name] = expression
753
754            # Make sure to not include the joins twice
755            if expression is not scope.expression:
756                expressions.extend(join.this for join in expression.args.get("joins") or [])
757
758            continue
759
760        if not isinstance(expression, exp.DerivedTable):
761            continue
762
763        if isinstance(expression, exp.UDTF):
764            lateral_sources = sources
765            scope_type = ScopeType.UDTF
766            scopes = scope.udtf_scopes
767        elif _is_derived_table(expression):
768            lateral_sources = None
769            scope_type = ScopeType.DERIVED_TABLE
770            scopes = scope.derived_table_scopes
771            expressions.extend(join.this for join in expression.args.get("joins") or [])
772        else:
773            # Makes sure we check for possible sources in nested table constructs
774            expressions.append(expression.this)
775            expressions.extend(join.this for join in expression.args.get("joins") or [])
776            continue
777
778        child_scope = None
779
780        for child_scope in _traverse_scope(
781            scope.branch(
782                expression,
783                lateral_sources=lateral_sources,
784                outer_columns=expression.alias_column_names,
785                scope_type=scope_type,
786            )
787        ):
788            yield child_scope
789
790            # Tables without aliases will be set as ""
791            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
792            # Until then, this means that only a single, unaliased derived table is allowed (rather,
793            # the latest one wins.
794            sources[_get_source_alias(expression)] = child_scope
795
796        # append the final child_scope yielded
797        if child_scope:
798            scopes.append(child_scope)
799            scope.table_scopes.append(child_scope)
800
801    scope.sources.update(sources)
802
803
804def _traverse_subqueries(scope):
805    for subquery in scope.subqueries:
806        top = None
807        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
808            yield child_scope
809            top = child_scope
810        scope.subquery_scopes.append(top)
811
812
813def _traverse_udtfs(scope):
814    if isinstance(scope.expression, exp.Unnest):
815        expressions = scope.expression.expressions
816    elif isinstance(scope.expression, exp.Lateral):
817        expressions = [scope.expression.this]
818    else:
819        expressions = []
820
821    sources = {}
822    for expression in expressions:
823        if _is_derived_table(expression):
824            top = None
825            for child_scope in _traverse_scope(
826                scope.branch(
827                    expression,
828                    scope_type=ScopeType.SUBQUERY,
829                    outer_columns=expression.alias_column_names,
830                )
831            ):
832                yield child_scope
833                top = child_scope
834                sources[_get_source_alias(expression)] = child_scope
835
836            scope.subquery_scopes.append(top)
837
838    scope.sources.update(sources)
839
840
841def walk_in_scope(expression, bfs=True, prune=None):
842    """
843    Returns a generator object which visits all nodes in the syntrax tree, stopping at
844    nodes that start child scopes.
845
846    Args:
847        expression (exp.Expression):
848        bfs (bool): if set to True the BFS traversal order will be applied,
849            otherwise the DFS traversal will be used instead.
850        prune ((node, parent, arg_key) -> bool): callable that returns True if
851            the generator should stop traversing this branch of the tree.
852
853    Yields:
854        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
855    """
856    # We'll use this variable to pass state into the dfs generator.
857    # Whenever we set it to True, we exclude a subtree from traversal.
858    crossed_scope_boundary = False
859
860    for node in expression.walk(
861        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
862    ):
863        crossed_scope_boundary = False
864
865        yield node
866
867        if node is expression:
868            continue
869
870        if (
871            isinstance(node, exp.CTE)
872            or (
873                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
874                and _is_derived_table(node)
875            )
876            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
877            or isinstance(node, exp.UNWRAPPED_QUERIES)
878        ):
879            crossed_scope_boundary = True
880
881            if isinstance(node, (exp.Subquery, exp.UDTF)):
882                # The following args are not actually in the inner scope, so we should visit them
883                for key in ("joins", "laterals", "pivots"):
884                    for arg in node.args.get(key) or []:
885                        yield from walk_in_scope(arg, bfs=bfs)
886
887
888def find_all_in_scope(expression, expression_types, bfs=True):
889    """
890    Returns a generator object which visits all nodes in this scope and only yields those that
891    match at least one of the specified expression types.
892
893    This does NOT traverse into subscopes.
894
895    Args:
896        expression (exp.Expression):
897        expression_types (tuple[type]|type): the expression type(s) to match.
898        bfs (bool): True to use breadth-first search, False to use depth-first.
899
900    Yields:
901        exp.Expression: nodes
902    """
903    for expression in walk_in_scope(expression, bfs=bfs):
904        if isinstance(expression, tuple(ensure_collection(expression_types))):
905            yield expression
906
907
908def find_in_scope(expression, expression_types, bfs=True):
909    """
910    Returns the first node in this scope which matches at least one of the specified types.
911
912    This does NOT traverse into subscopes.
913
914    Args:
915        expression (exp.Expression):
916        expression_types (tuple[type]|type): the expression type(s) to match.
917        bfs (bool): True to use breadth-first search, False to use depth-first.
918
919    Returns:
920        exp.Expression: the node which matches the criteria or None if no node matching
921        the criteria was found.
922    """
923    return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
924
925
926def _get_source_alias(expression):
927    alias_arg = expression.args.get("alias")
928    alias_name = expression.alias
929
930    if not alias_name and isinstance(alias_arg, exp.TableAlias) and len(alias_arg.columns) == 1:
931        alias_name = alias_arg.columns[0].name
932
933    return alias_name
logger = <Logger sqlglot (WARNING)>
TRAVERSABLES = (<class 'sqlglot.expressions.Query'>, <class 'sqlglot.expressions.DDL'>, <class 'sqlglot.expressions.DML'>)
class ScopeType(enum.Enum):
19class ScopeType(Enum):
20    ROOT = auto()
21    SUBQUERY = auto()
22    DERIVED_TABLE = auto()
23    CTE = auto()
24    UNION = auto()
25    UDTF = auto()

Enumeration of scope types: the kind of each scope relative to its parent scope.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
class Scope:
 28class Scope:
 29    """
 30    Selection scope.
 31
 32    Attributes:
 33        expression (exp.Select|exp.SetOperation): Root expression of this scope
 34        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 35            a Table expression or another Scope instance. For example:
 36                SELECT * FROM x                     {"x": Table(this="x")}
 37                SELECT * FROM x AS y                {"y": Table(this="x")}
 38                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 39        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 40            For example:
 41                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 42            The LATERAL VIEW EXPLODE gets x as a source.
 43        cte_sources (dict[str, Scope]): Sources from CTES
 44        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
 45            defines a column list for the alias of this scope, this is that list of columns.
 46            For example:
 47                SELECT * FROM (SELECT ...) AS y(col1, col2)
 48            The inner query would have `["col1", "col2"]` for its `outer_columns`
 49        parent (Scope): Parent scope
 50        scope_type (ScopeType): Type of this scope, relative to it's parent
 51        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 52        cte_scopes (list[Scope]): List of all child scopes for CTEs
 53        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 54        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 55        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 56        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 57            a list of the left and right child scopes.
 58    """
 59
 60    def __init__(
 61        self,
 62        expression,
 63        sources=None,
 64        outer_columns=None,
 65        parent=None,
 66        scope_type=ScopeType.ROOT,
 67        lateral_sources=None,
 68        cte_sources=None,
 69        can_be_correlated=None,
 70    ):
 71        self.expression = expression
 72        self.sources = sources or {}
 73        self.lateral_sources = lateral_sources or {}
 74        self.cte_sources = cte_sources or {}
 75        self.sources.update(self.lateral_sources)
 76        self.sources.update(self.cte_sources)
 77        self.outer_columns = outer_columns or []
 78        self.parent = parent
 79        self.scope_type = scope_type
 80        self.subquery_scopes = []
 81        self.derived_table_scopes = []
 82        self.table_scopes = []
 83        self.cte_scopes = []
 84        self.union_scopes = []
 85        self.udtf_scopes = []
 86        self.can_be_correlated = can_be_correlated
 87        self.clear_cache()
 88
 89    def clear_cache(self):
 90        self._collected = False
 91        self._raw_columns = None
 92        self._table_columns = None
 93        self._stars = None
 94        self._derived_tables = None
 95        self._udtfs = None
 96        self._tables = None
 97        self._ctes = None
 98        self._subqueries = None
 99        self._selected_sources = None
100        self._columns = None
101        self._external_columns = None
102        self._join_hints = None
103        self._pivots = None
104        self._references = None
105        self._semi_anti_join_tables = None
106
107    def branch(
108        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
109    ):
110        """Branch from the current scope to a new, inner scope"""
111        return Scope(
112            expression=expression.unnest(),
113            sources=sources.copy() if sources else None,
114            parent=self,
115            scope_type=scope_type,
116            cte_sources={**self.cte_sources, **(cte_sources or {})},
117            lateral_sources=lateral_sources.copy() if lateral_sources else None,
118            can_be_correlated=self.can_be_correlated
119            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
120            **kwargs,
121        )
122
123    def _collect(self):
124        self._tables = []
125        self._ctes = []
126        self._subqueries = []
127        self._derived_tables = []
128        self._udtfs = []
129        self._raw_columns = []
130        self._table_columns = []
131        self._stars = []
132        self._join_hints = []
133        self._semi_anti_join_tables = set()
134
135        for node in self.walk(bfs=False):
136            if node is self.expression:
137                continue
138
139            if isinstance(node, exp.Dot) and node.is_star:
140                self._stars.append(node)
141            elif isinstance(node, exp.Column):
142                if isinstance(node.this, exp.Star):
143                    self._stars.append(node)
144                else:
145                    self._raw_columns.append(node)
146            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
147                parent = node.parent
148                if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
149                    self._semi_anti_join_tables.add(node.alias_or_name)
150
151                self._tables.append(node)
152            elif isinstance(node, exp.JoinHint):
153                self._join_hints.append(node)
154            elif isinstance(node, exp.UDTF):
155                self._udtfs.append(node)
156            elif isinstance(node, exp.CTE):
157                self._ctes.append(node)
158            elif _is_derived_table(node) and _is_from_or_join(node):
159                self._derived_tables.append(node)
160            elif isinstance(node, exp.UNWRAPPED_QUERIES):
161                self._subqueries.append(node)
162            elif isinstance(node, exp.TableColumn):
163                self._table_columns.append(node)
164
165        self._collected = True
166
167    def _ensure_collected(self):
168        if not self._collected:
169            self._collect()
170
171    def walk(self, bfs=True, prune=None):
172        return walk_in_scope(self.expression, bfs=bfs, prune=None)
173
174    def find(self, *expression_types, bfs=True):
175        return find_in_scope(self.expression, expression_types, bfs=bfs)
176
177    def find_all(self, *expression_types, bfs=True):
178        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
179
180    def replace(self, old, new):
181        """
182        Replace `old` with `new`.
183
184        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
185
186        Args:
187            old (exp.Expression): old node
188            new (exp.Expression): new node
189        """
190        old.replace(new)
191        self.clear_cache()
192
193    @property
194    def tables(self):
195        """
196        List of tables in this scope.
197
198        Returns:
199            list[exp.Table]: tables
200        """
201        self._ensure_collected()
202        return self._tables
203
204    @property
205    def ctes(self):
206        """
207        List of CTEs in this scope.
208
209        Returns:
210            list[exp.CTE]: ctes
211        """
212        self._ensure_collected()
213        return self._ctes
214
215    @property
216    def derived_tables(self):
217        """
218        List of derived tables in this scope.
219
220        For example:
221            SELECT * FROM (SELECT ...) <- that's a derived table
222
223        Returns:
224            list[exp.Subquery]: derived tables
225        """
226        self._ensure_collected()
227        return self._derived_tables
228
229    @property
230    def udtfs(self):
231        """
232        List of "User Defined Tabular Functions" in this scope.
233
234        Returns:
235            list[exp.UDTF]: UDTFs
236        """
237        self._ensure_collected()
238        return self._udtfs
239
240    @property
241    def subqueries(self):
242        """
243        List of subqueries in this scope.
244
245        For example:
246            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
247
248        Returns:
249            list[exp.Select | exp.SetOperation]: subqueries
250        """
251        self._ensure_collected()
252        return self._subqueries
253
254    @property
255    def stars(self) -> t.List[exp.Column | exp.Dot]:
256        """
257        List of star expressions (columns or dots) in this scope.
258        """
259        self._ensure_collected()
260        return self._stars
261
262    @property
263    def columns(self):
264        """
265        List of columns in this scope.
266
267        Returns:
268            list[exp.Column]: Column instances in this scope, plus any
269                Columns that reference this scope from correlated subqueries.
270        """
271        if self._columns is None:
272            self._ensure_collected()
273            columns = self._raw_columns
274
275            external_columns = [
276                column
277                for scope in itertools.chain(
278                    self.subquery_scopes,
279                    self.udtf_scopes,
280                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
281                )
282                for column in scope.external_columns
283            ]
284
285            named_selects = set(self.expression.named_selects)
286
287            self._columns = []
288            for column in columns + external_columns:
289                ancestor = column.find_ancestor(
290                    exp.Select,
291                    exp.Qualify,
292                    exp.Order,
293                    exp.Having,
294                    exp.Hint,
295                    exp.Table,
296                    exp.Star,
297                    exp.Distinct,
298                )
299                if (
300                    not ancestor
301                    or column.table
302                    or isinstance(ancestor, exp.Select)
303                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
304                    or (
305                        isinstance(ancestor, (exp.Order, exp.Distinct))
306                        and (
307                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
308                            or column.name not in named_selects
309                        )
310                    )
311                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
312                ):
313                    self._columns.append(column)
314
315        return self._columns
316
317    @property
318    def table_columns(self):
319        if self._table_columns is None:
320            self._ensure_collected()
321
322        return self._table_columns
323
324    @property
325    def selected_sources(self):
326        """
327        Mapping of nodes and sources that are actually selected from in this scope.
328
329        That is, all tables in a schema are selectable at any point. But a
330        table only becomes a selected source if it's included in a FROM or JOIN clause.
331
332        Returns:
333            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
334        """
335        if self._selected_sources is None:
336            result = {}
337
338            for name, node in self.references:
339                if name in self._semi_anti_join_tables:
340                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
341                    # selected source
342                    continue
343
344                if name in result:
345                    raise OptimizeError(f"Alias already used: {name}")
346                if name in self.sources:
347                    result[name] = (node, self.sources[name])
348
349            self._selected_sources = result
350        return self._selected_sources
351
352    @property
353    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
354        if self._references is None:
355            self._references = []
356
357            for table in self.tables:
358                self._references.append((table.alias_or_name, table))
359            for expression in itertools.chain(self.derived_tables, self.udtfs):
360                self._references.append(
361                    (
362                        _get_source_alias(expression),
363                        expression if expression.args.get("pivots") else expression.unnest(),
364                    )
365                )
366
367        return self._references
368
369    @property
370    def external_columns(self):
371        """
372        Columns that appear to reference sources in outer scopes.
373
374        Returns:
375            list[exp.Column]: Column instances that don't reference
376                sources in the current scope.
377        """
378        if self._external_columns is None:
379            if isinstance(self.expression, exp.SetOperation):
380                left, right = self.union_scopes
381                self._external_columns = left.external_columns + right.external_columns
382            else:
383                self._external_columns = [
384                    c
385                    for c in self.columns
386                    if c.table not in self.selected_sources
387                    and c.table not in self.semi_or_anti_join_tables
388                ]
389
390        return self._external_columns
391
392    @property
393    def unqualified_columns(self):
394        """
395        Unqualified columns in the current scope.
396
397        Returns:
398             list[exp.Column]: Unqualified columns
399        """
400        return [c for c in self.columns if not c.table]
401
402    @property
403    def join_hints(self):
404        """
405        Hints that exist in the scope that reference tables
406
407        Returns:
408            list[exp.JoinHint]: Join hints that are referenced within the scope
409        """
410        if self._join_hints is None:
411            return []
412        return self._join_hints
413
414    @property
415    def pivots(self):
416        if not self._pivots:
417            self._pivots = [
418                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
419            ]
420
421        return self._pivots
422
423    @property
424    def semi_or_anti_join_tables(self):
425        return self._semi_anti_join_tables or set()
426
427    def source_columns(self, source_name):
428        """
429        Get all columns in the current scope for a particular source.
430
431        Args:
432            source_name (str): Name of the source
433        Returns:
434            list[exp.Column]: Column instances that reference `source_name`
435        """
436        return [column for column in self.columns if column.table == source_name]
437
438    @property
439    def is_subquery(self):
440        """Determine if this scope is a subquery"""
441        return self.scope_type == ScopeType.SUBQUERY
442
443    @property
444    def is_derived_table(self):
445        """Determine if this scope is a derived table"""
446        return self.scope_type == ScopeType.DERIVED_TABLE
447
448    @property
449    def is_union(self):
450        """Determine if this scope is a union"""
451        return self.scope_type == ScopeType.UNION
452
453    @property
454    def is_cte(self):
455        """Determine if this scope is a common table expression"""
456        return self.scope_type == ScopeType.CTE
457
458    @property
459    def is_root(self):
460        """Determine if this is the root scope"""
461        return self.scope_type == ScopeType.ROOT
462
463    @property
464    def is_udtf(self):
465        """Determine if this scope is a UDTF (User Defined Table Function)"""
466        return self.scope_type == ScopeType.UDTF
467
468    @property
469    def is_correlated_subquery(self):
470        """Determine if this scope is a correlated subquery"""
471        return bool(self.can_be_correlated and self.external_columns)
472
473    def rename_source(self, old_name, new_name):
474        """Rename a source in this scope"""
475        columns = self.sources.pop(old_name or "", [])
476        self.sources[new_name] = columns
477
478    def add_source(self, name, source):
479        """Add a source to this scope"""
480        self.sources[name] = source
481        self.clear_cache()
482
483    def remove_source(self, name):
484        """Remove a source from this scope"""
485        self.sources.pop(name, None)
486        self.clear_cache()
487
488    def __repr__(self):
489        return f"Scope<{self.expression.sql()}>"
490
491    def traverse(self):
492        """
493        Traverse the scope tree from this node.
494
495        Yields:
496            Scope: scope instances in depth-first-search post-order
497        """
498        stack = [self]
499        result = []
500        while stack:
501            scope = stack.pop()
502            result.append(scope)
503            stack.extend(
504                itertools.chain(
505                    scope.cte_scopes,
506                    scope.union_scopes,
507                    scope.table_scopes,
508                    scope.subquery_scopes,
509                )
510            )
511
512        yield from reversed(result)
513
514    def ref_count(self):
515        """
516        Count the number of times each scope in this tree is referenced.
517
518        Returns:
519            dict[int, int]: Mapping of Scope instance ID to reference count
520        """
521        scope_ref_count = defaultdict(lambda: 0)
522
523        for scope in self.traverse():
524            for _, source in scope.selected_sources.values():
525                scope_ref_count[id(source)] += 1
526
527            for name in scope._semi_anti_join_tables:
528                # semi/anti join sources are not actually selected but we still need to
529                # increment their ref count to avoid them being optimized away
530                if name in scope.sources:
531                    scope_ref_count[id(scope.sources[name])] += 1
532
533        return scope_ref_count

Selection scope.

Attributes:
  • expression (exp.Select|exp.SetOperation): Root expression of this scope
  • sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x yields {"x": Table(this="x")}; SELECT * FROM x AS y yields {"y": Table(this="x")}; SELECT * FROM (SELECT ...) AS y yields {"y": Scope(...)}.
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; the LATERAL VIEW EXPLODE gets x as a source.
  • cte_sources (dict[str, Scope]): Sources from CTEs.
  • outer_columns (list[str]): If this is a derived table or CTE, and the outer query defines a column list for the alias of this scope, this is that list of columns. For example, for SELECT * FROM (SELECT ...) AS y(col1, col2), the inner query would have ["col1", "col2"] as its outer_columns.
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to its parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • derived_table_scopes (list[Scope]): List of all child scopes for derived tables
  • udtf_scopes (list[Scope]): List of all child scopes for user-defined tabular functions
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a set operation (e.g. a UNION), this will be a list of the left and right child scopes.
Scope( expression, sources=None, outer_columns=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None, cte_sources=None, can_be_correlated=None)
60    def __init__(
61        self,
62        expression,
63        sources=None,
64        outer_columns=None,
65        parent=None,
66        scope_type=ScopeType.ROOT,
67        lateral_sources=None,
68        cte_sources=None,
69        can_be_correlated=None,
70    ):
71        self.expression = expression
72        self.sources = sources or {}
73        self.lateral_sources = lateral_sources or {}
74        self.cte_sources = cte_sources or {}
75        self.sources.update(self.lateral_sources)
76        self.sources.update(self.cte_sources)
77        self.outer_columns = outer_columns or []
78        self.parent = parent
79        self.scope_type = scope_type
80        self.subquery_scopes = []
81        self.derived_table_scopes = []
82        self.table_scopes = []
83        self.cte_scopes = []
84        self.union_scopes = []
85        self.udtf_scopes = []
86        self.can_be_correlated = can_be_correlated
87        self.clear_cache()
expression
sources
lateral_sources
cte_sources
outer_columns
parent
scope_type
subquery_scopes
derived_table_scopes
table_scopes
cte_scopes
union_scopes
udtf_scopes
can_be_correlated
def clear_cache(self):
 89    def clear_cache(self):
 90        self._collected = False
 91        self._raw_columns = None
 92        self._table_columns = None
 93        self._stars = None
 94        self._derived_tables = None
 95        self._udtfs = None
 96        self._tables = None
 97        self._ctes = None
 98        self._subqueries = None
 99        self._selected_sources = None
100        self._columns = None
101        self._external_columns = None
102        self._join_hints = None
103        self._pivots = None
104        self._references = None
105        self._semi_anti_join_tables = None
def branch( self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs):
107    def branch(
108        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
109    ):
110        """Branch from the current scope to a new, inner scope"""
111        return Scope(
112            expression=expression.unnest(),
113            sources=sources.copy() if sources else None,
114            parent=self,
115            scope_type=scope_type,
116            cte_sources={**self.cte_sources, **(cte_sources or {})},
117            lateral_sources=lateral_sources.copy() if lateral_sources else None,
118            can_be_correlated=self.can_be_correlated
119            or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
120            **kwargs,
121        )

Branch from the current scope to a new, inner scope

def walk(self, bfs=True, prune=None):
171    def walk(self, bfs=True, prune=None):
172        return walk_in_scope(self.expression, bfs=bfs, prune=None)
def find(self, *expression_types, bfs=True):
174    def find(self, *expression_types, bfs=True):
175        return find_in_scope(self.expression, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
177    def find_all(self, *expression_types, bfs=True):
178        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
def replace(self, old, new):
180    def replace(self, old, new):
181        """
182        Replace `old` with `new`.
183
184        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
185
186        Args:
187            old (exp.Expression): old node
188            new (exp.Expression): new node
189        """
190        old.replace(new)
191        self.clear_cache()

Replace old with new.

This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expression): old node
  • new (exp.Expression): new node
tables
193    @property
194    def tables(self):
195        """
196        List of tables in this scope.
197
198        Returns:
199            list[exp.Table]: tables
200        """
201        self._ensure_collected()
202        return self._tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes
204    @property
205    def ctes(self):
206        """
207        List of CTEs in this scope.
208
209        Returns:
210            list[exp.CTE]: ctes
211        """
212        self._ensure_collected()
213        return self._ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables
215    @property
216    def derived_tables(self):
217        """
218        List of derived tables in this scope.
219
220        For example:
221            SELECT * FROM (SELECT ...) <- that's a derived table
222
223        Returns:
224            list[exp.Subquery]: derived tables
225        """
226        self._ensure_collected()
227        return self._derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs
229    @property
230    def udtfs(self):
231        """
232        List of "User Defined Tabular Functions" in this scope.
233
234        Returns:
235            list[exp.UDTF]: UDTFs
236        """
237        self._ensure_collected()
238        return self._udtfs

List of "User Defined Tabular Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

subqueries
240    @property
241    def subqueries(self):
242        """
243        List of subqueries in this scope.
244
245        For example:
246            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
247
248        Returns:
249            list[exp.Select | exp.SetOperation]: subqueries
250        """
251        self._ensure_collected()
252        return self._subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Select | exp.SetOperation]: subqueries

254    @property
255    def stars(self) -> t.List[exp.Column | exp.Dot]:
256        """
257        List of star expressions (columns or dots) in this scope.
258        """
259        self._ensure_collected()
260        return self._stars

List of star expressions (columns or dots) in this scope.

columns
262    @property
263    def columns(self):
264        """
265        List of columns in this scope.
266
267        Returns:
268            list[exp.Column]: Column instances in this scope, plus any
269                Columns that reference this scope from correlated subqueries.
270        """
271        if self._columns is None:
272            self._ensure_collected()
273            columns = self._raw_columns
274
275            external_columns = [
276                column
277                for scope in itertools.chain(
278                    self.subquery_scopes,
279                    self.udtf_scopes,
280                    (dts for dts in self.derived_table_scopes if dts.can_be_correlated),
281                )
282                for column in scope.external_columns
283            ]
284
285            named_selects = set(self.expression.named_selects)
286
287            self._columns = []
288            for column in columns + external_columns:
289                ancestor = column.find_ancestor(
290                    exp.Select,
291                    exp.Qualify,
292                    exp.Order,
293                    exp.Having,
294                    exp.Hint,
295                    exp.Table,
296                    exp.Star,
297                    exp.Distinct,
298                )
299                if (
300                    not ancestor
301                    or column.table
302                    or isinstance(ancestor, exp.Select)
303                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
304                    or (
305                        isinstance(ancestor, (exp.Order, exp.Distinct))
306                        and (
307                            isinstance(ancestor.parent, (exp.Window, exp.WithinGroup))
308                            or column.name not in named_selects
309                        )
310                    )
311                    or (isinstance(ancestor, exp.Star) and not column.arg_key == "except")
312                ):
313                    self._columns.append(column)
314
315        return self._columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

table_columns
317    @property
318    def table_columns(self):
319        if self._table_columns is None:
320            self._ensure_collected()
321
322        return self._table_columns
selected_sources
324    @property
325    def selected_sources(self):
326        """
327        Mapping of nodes and sources that are actually selected from in this scope.
328
329        That is, all tables in a schema are selectable at any point. But a
330        table only becomes a selected source if it's included in a FROM or JOIN clause.
331
332        Returns:
333            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
334        """
335        if self._selected_sources is None:
336            result = {}
337
338            for name, node in self.references:
339                if name in self._semi_anti_join_tables:
340                    # The RHS table of SEMI/ANTI joins shouldn't be collected as a
341                    # selected source
342                    continue
343
344                if name in result:
345                    raise OptimizeError(f"Alias already used: {name}")
346                if name in self.sources:
347                    result[name] = (node, self.sources[name])
348
349            self._selected_sources = result
350        return self._selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

references: List[Tuple[str, sqlglot.expressions.Expression]]
352    @property
353    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
354        if self._references is None:
355            self._references = []
356
357            for table in self.tables:
358                self._references.append((table.alias_or_name, table))
359            for expression in itertools.chain(self.derived_tables, self.udtfs):
360                self._references.append(
361                    (
362                        _get_source_alias(expression),
363                        expression if expression.args.get("pivots") else expression.unnest(),
364                    )
365                )
366
367        return self._references
external_columns
369    @property
370    def external_columns(self):
371        """
372        Columns that appear to reference sources in outer scopes.
373
374        Returns:
375            list[exp.Column]: Column instances that don't reference
376                sources in the current scope.
377        """
378        if self._external_columns is None:
379            if isinstance(self.expression, exp.SetOperation):
380                left, right = self.union_scopes
381                self._external_columns = left.external_columns + right.external_columns
382            else:
383                self._external_columns = [
384                    c
385                    for c in self.columns
386                    if c.table not in self.selected_sources
387                    and c.table not in self.semi_or_anti_join_tables
388                ]
389
390        return self._external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

unqualified_columns
392    @property
393    def unqualified_columns(self):
394        """
395        Unqualified columns in the current scope.
396
397        Returns:
398             list[exp.Column]: Unqualified columns
399        """
400        return [c for c in self.columns if not c.table]

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints
402    @property
403    def join_hints(self):
404        """
405        Hints that exist in the scope that reference tables
406
407        Returns:
408            list[exp.JoinHint]: Join hints that are referenced within the scope
409        """
410        if self._join_hints is None:
411            return []
412        return self._join_hints

Join hints in this scope that reference tables.

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

pivots
414    @property
415    def pivots(self):
416        if not self._pivots:
417            self._pivots = [
418                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
419            ]
420
421        return self._pivots
semi_or_anti_join_tables
423    @property
424    def semi_or_anti_join_tables(self):
425        return self._semi_anti_join_tables or set()
def source_columns(self, source_name):
427    def source_columns(self, source_name):
428        """
429        Get all columns in the current scope for a particular source.
430
431        Args:
432            source_name (str): Name of the source
433        Returns:
434            list[exp.Column]: Column instances that reference `source_name`
435        """
436        return [column for column in self.columns if column.table == source_name]

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery
438    @property
439    def is_subquery(self):
440        """Determine if this scope is a subquery"""
441        return self.scope_type == ScopeType.SUBQUERY

Determine if this scope is a subquery

is_derived_table
443    @property
444    def is_derived_table(self):
445        """Determine if this scope is a derived table"""
446        return self.scope_type == ScopeType.DERIVED_TABLE

Determine if this scope is a derived table

is_union
448    @property
449    def is_union(self):
450        """Determine if this scope is a union"""
451        return self.scope_type == ScopeType.UNION

Determine if this scope is a union

is_cte
453    @property
454    def is_cte(self):
455        """Determine if this scope is a common table expression"""
456        return self.scope_type == ScopeType.CTE

Determine if this scope is a common table expression

is_root
458    @property
459    def is_root(self):
460        """Determine if this is the root scope"""
461        return self.scope_type == ScopeType.ROOT

Determine if this is the root scope

is_udtf
463    @property
464    def is_udtf(self):
465        """Determine if this scope is a UDTF (User Defined Table Function)"""
466        return self.scope_type == ScopeType.UDTF

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery
468    @property
469    def is_correlated_subquery(self):
470        """Determine if this scope is a correlated subquery"""
471        return bool(self.can_be_correlated and self.external_columns)

Determine if this scope is a correlated subquery

def rename_source(self, old_name, new_name):
473    def rename_source(self, old_name, new_name):
474        """Rename a source in this scope"""
475        columns = self.sources.pop(old_name or "", [])
476        self.sources[new_name] = columns

Rename a source in this scope

def add_source(self, name, source):
478    def add_source(self, name, source):
479        """Add a source to this scope"""
480        self.sources[name] = source
481        self.clear_cache()

Add a source to this scope

def remove_source(self, name):
483    def remove_source(self, name):
484        """Remove a source from this scope"""
485        self.sources.pop(name, None)
486        self.clear_cache()

Remove a source from this scope

def traverse(self):
491    def traverse(self):
492        """
493        Traverse the scope tree from this node.
494
495        Yields:
496            Scope: scope instances in depth-first-search post-order
497        """
498        stack = [self]
499        result = []
500        while stack:
501            scope = stack.pop()
502            result.append(scope)
503            stack.extend(
504                itertools.chain(
505                    scope.cte_scopes,
506                    scope.union_scopes,
507                    scope.table_scopes,
508                    scope.subquery_scopes,
509                )
510            )
511
512        yield from reversed(result)

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self):
514    def ref_count(self):
515        """
516        Count the number of times each scope in this tree is referenced.
517
518        Returns:
519            dict[int, int]: Mapping of Scope instance ID to reference count
520        """
521        scope_ref_count = defaultdict(lambda: 0)
522
523        for scope in self.traverse():
524            for _, source in scope.selected_sources.values():
525                scope_ref_count[id(source)] += 1
526
527            for name in scope._semi_anti_join_tables:
528                # semi/anti join sources are not actually selected but we still need to
529                # increment their ref count to avoid them being optimized away
530                if name in scope.sources:
531                    scope_ref_count[id(scope.sources[name])] += 1
532
533        return scope_ref_count

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope(expression: sqlglot.expressions.Expression) -> List[Scope]:
536def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
537    """
538    Traverse an expression by its "scopes".
539
540    "Scope" represents the current context of a Select statement.
541
542    This is helpful for optimizing queries, where we need more information than
543    the expression tree itself. For example, we might care about the source
544    names within a subquery. Returns a list because a generator could result in
545    incomplete properties which is confusing.
546
547    Examples:
548        >>> import sqlglot
549        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
550        >>> scopes = traverse_scope(expression)
551        >>> scopes[0].expression.sql(), list(scopes[0].sources)
552        ('SELECT a FROM x', ['x'])
553        >>> scopes[1].expression.sql(), list(scopes[1].sources)
554        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
555
556    Args:
557        expression: Expression to traverse
558
559    Returns:
560        A list of the created scope instances
561    """
562    if isinstance(expression, TRAVERSABLES):
563        return list(_traverse_scope(Scope(expression)))
564    return []

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression: Expression to traverse
Returns:

A list of the created scope instances

def build_scope(expression: sqlglot.expressions.Expression) -> Optional[Scope]:
567def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
568    """
569    Build a scope tree.
570
571    Args:
572        expression: Expression to build the scope tree for.
573
574    Returns:
575        The root scope
576    """
577    return seq_get(traverse_scope(expression), -1)

Build a scope tree.

Arguments:
  • expression: Expression to build the scope tree for.
Returns:

The root scope

def walk_in_scope(expression, bfs=True, prune=None):
842def walk_in_scope(expression, bfs=True, prune=None):
843    """
844    Returns a generator object which visits all nodes in the syntrax tree, stopping at
845    nodes that start child scopes.
846
847    Args:
848        expression (exp.Expression):
849        bfs (bool): if set to True the BFS traversal order will be applied,
850            otherwise the DFS traversal will be used instead.
851        prune ((node, parent, arg_key) -> bool): callable that returns True if
852            the generator should stop traversing this branch of the tree.
853
854    Yields:
855        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
856    """
857    # We'll use this variable to pass state into the dfs generator.
858    # Whenever we set it to True, we exclude a subtree from traversal.
859    crossed_scope_boundary = False
860
861    for node in expression.walk(
862        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
863    ):
864        crossed_scope_boundary = False
865
866        yield node
867
868        if node is expression:
869            continue
870
871        if (
872            isinstance(node, exp.CTE)
873            or (
874                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
875                and _is_derived_table(node)
876            )
877            or (isinstance(node.parent, exp.UDTF) and isinstance(node, exp.Query))
878            or isinstance(node, exp.UNWRAPPED_QUERIES)
879        ):
880            crossed_scope_boundary = True
881
882            if isinstance(node, (exp.Subquery, exp.UDTF)):
883                # The following args are not actually in the inner scope, so we should visit them
884                for key in ("joins", "laterals", "pivots"):
885                    for arg in node.args.get(key) or []:
886                        yield from walk_in_scope(arg, bfs=bfs)

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.

Arguments:
  • expression (exp.Expression): the expression to start walking from.
  • bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
  • prune ((node) -> bool): callable that receives the current node and returns True if the generator should stop traversing this branch of the tree.
Yields:

exp.Expression: each node visited within the current scope.

def find_all_in_scope(expression, expression_types, bfs=True):
889def find_all_in_scope(expression, expression_types, bfs=True):
890    """
891    Returns a generator object which visits all nodes in this scope and only yields those that
892    match at least one of the specified expression types.
893
894    This does NOT traverse into subscopes.
895
896    Args:
897        expression (exp.Expression):
898        expression_types (tuple[type]|type): the expression type(s) to match.
899        bfs (bool): True to use breadth-first search, False to use depth-first.
900
901    Yields:
902        exp.Expression: nodes
903    """
904    for expression in walk_in_scope(expression, bfs=bfs):
905        if isinstance(expression, tuple(ensure_collection(expression_types))):
906            yield expression

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression): the expression whose scope is searched.
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:

exp.Expression: nodes

def find_in_scope(expression, expression_types, bfs=True):
909def find_in_scope(expression, expression_types, bfs=True):
910    """
911    Returns the first node in this scope which matches at least one of the specified types.
912
913    This does NOT traverse into subscopes.
914
915    Args:
916        expression (exp.Expression):
917        expression_types (tuple[type]|type): the expression type(s) to match.
918        bfs (bool): True to use breadth-first search, False to use depth-first.
919
920    Returns:
921        exp.Expression: the node which matches the criteria or None if no node matching
922        the criteria was found.
923    """
924    return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression): the expression whose scope is searched.
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:

exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.