Edit on GitHub

sqlglot.optimizer.eliminate_ctes

 1from __future__ import annotations
 2
 3import typing as t
 4
 5from sqlglot.optimizer.scope import Scope, build_scope
 6
 7
 8if t.TYPE_CHECKING:
 9    from sqlglot._typing import E
10
11
def eliminate_ctes(expression: E) -> E:
    """
    Remove unused CTEs from an expression.

    A CTE is dropped when its reference count (from ``Scope.ref_count``) is not
    positive. Scopes are visited in the reverse of ``root.traverse()`` order, and
    each removed CTE decrements the counts of the scopes it selects from. This is
    what lets a chain of CTEs that only feed one another be removed in one pass.
    When the last CTE of a WITH clause is removed, the empty WITH clause is
    removed as well. The tree is edited in place and the same object is returned.

    Example:
        >>> import sqlglot
        >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_ctes(expression).sql()
        'SELECT a FROM z'

    Args:
        expression (sqlglot.Expr): expression to optimize
    Returns:
        sqlglot.Expr: optimized expression
    """
    root = build_scope(expression)
    if not root:
        # No scope tree could be built: nothing to eliminate.
        return expression

    refs = root.ref_count()

    # Visit scopes in reverse so that dropping one CTE can lower the counts of
    # the CTEs it reads from before those are examined (removes whole chains).
    for current in reversed(list(root.traverse())):
        if not current.is_cte or refs[id(current)] > 0:
            continue

        cte = current.expression.parent
        if not cte:
            continue

        with_clause = cte.parent
        cte.pop()

        # Once the last CTE is gone, drop the now-empty WITH clause as well.
        if with_clause and len(with_clause.expressions) <= 0:
            with_clause.pop()

        # The removed CTE no longer references anything it selected from.
        for _node, source in current.selected_sources.values():
            if isinstance(source, Scope):
                refs[id(source)] -= 1

    return expression
def eliminate_ctes(expression: ~E) -> ~E:
def eliminate_ctes(expression: E) -> E:
    """
    Remove unused CTEs from an expression.

    A CTE is dropped when its reference count (from ``Scope.ref_count``) is not
    positive. Scopes are visited in the reverse of ``root.traverse()`` order, and
    each removed CTE decrements the counts of the scopes it selects from. This is
    what lets a chain of CTEs that only feed one another be removed in one pass.
    When the last CTE of a WITH clause is removed, the empty WITH clause is
    removed as well. The tree is edited in place and the same object is returned.

    Example:
        >>> import sqlglot
        >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_ctes(expression).sql()
        'SELECT a FROM z'

    Args:
        expression (sqlglot.Expr): expression to optimize
    Returns:
        sqlglot.Expr: optimized expression
    """
    root = build_scope(expression)
    if not root:
        # No scope tree could be built: nothing to eliminate.
        return expression

    refs = root.ref_count()

    # Visit scopes in reverse so that dropping one CTE can lower the counts of
    # the CTEs it reads from before those are examined (removes whole chains).
    for current in reversed(list(root.traverse())):
        if not current.is_cte or refs[id(current)] > 0:
            continue

        cte = current.expression.parent
        if not cte:
            continue

        with_clause = cte.parent
        cte.pop()

        # Once the last CTE is gone, drop the now-empty WITH clause as well.
        if with_clause and len(with_clause.expressions) <= 0:
            with_clause.pop()

        # The removed CTE no longer references anything it selected from.
        for _node, source in current.selected_sources.values():
            if isinstance(source, Scope):
                refs[id(source)] -= 1

    return expression

Remove unused CTEs from an expression.

Example:
>>> import sqlglot
>>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z"
>>> expression = sqlglot.parse_one(sql)
>>> eliminate_ctes(expression).sql()
'SELECT a FROM z'
Arguments:
  • expression (sqlglot.Expr): expression to optimize
Returns:

sqlglot.Expr: optimized expression