Edit on GitHub

sqlglot.optimizer.eliminate_ctes

 1from __future__ import annotations
 2
 3import typing as t
 4
 5from sqlglot.optimizer.scope import Scope, build_scope
 6
 7
 8if t.TYPE_CHECKING:
 9    from sqlglot._typing import E
10
11
def eliminate_ctes(expression: E) -> E:
    """
    Remove unused CTEs from an expression.

    A CTE whose reference count is zero (or below) is detached from the tree.
    If that leaves its WITH clause without any CTEs, the WITH clause is
    detached as well. When a CTE is removed, every scope it selected from has
    its reference count lowered by one. Because the scopes are visited in
    reverse traversal order, this lets chains of unused CTEs be removed in a
    single pass.

    Example:
        >>> import sqlglot
        >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_ctes(expression).sql()
        'SELECT a FROM z'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    root_scope = build_scope(expression)
    if not root_scope:
        # Nothing to analyze; hand the expression back untouched.
        return expression

    counts = root_scope.ref_count()

    # Reverse traversal order is what allows chains of unused CTEs to be removed.
    for node_scope in reversed(list(root_scope.traverse())):
        if not node_scope.is_cte or counts[id(node_scope)] > 0:
            continue

        cte = node_scope.expression.parent
        with_clause = cte.parent
        cte.pop()

        # Once the last CTE is gone, the (now empty) WITH clause goes too.
        if with_clause and len(with_clause.expressions) <= 0:
            with_clause.pop()

        # The removed CTE no longer references its sources.
        for _, source in node_scope.selected_sources.values():
            if isinstance(source, Scope):
                counts[id(source)] -= 1

    return expression
def eliminate_ctes(expression: ~E) -> ~E:
def eliminate_ctes(expression: E) -> E:
    """
    Remove unused CTEs from an expression.

    A CTE whose reference count is zero (or below) is detached from the tree.
    If that leaves its WITH clause without any CTEs, the WITH clause is
    detached as well. When a CTE is removed, every scope it selected from has
    its reference count lowered by one. Because the scopes are visited in
    reverse traversal order, this lets chains of unused CTEs be removed in a
    single pass.

    Example:
        >>> import sqlglot
        >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_ctes(expression).sql()
        'SELECT a FROM z'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    root_scope = build_scope(expression)
    if not root_scope:
        # Nothing to analyze; hand the expression back untouched.
        return expression

    counts = root_scope.ref_count()

    # Reverse traversal order is what allows chains of unused CTEs to be removed.
    for node_scope in reversed(list(root_scope.traverse())):
        if not node_scope.is_cte or counts[id(node_scope)] > 0:
            continue

        cte = node_scope.expression.parent
        with_clause = cte.parent
        cte.pop()

        # Once the last CTE is gone, the (now empty) WITH clause goes too.
        if with_clause and len(with_clause.expressions) <= 0:
            with_clause.pop()

        # The removed CTE no longer references its sources.
        for _, source in node_scope.selected_sources.values():
            if isinstance(source, Scope):
                counts[id(source)] -= 1

    return expression

Remove unused CTEs from an expression.

Example:
>>> import sqlglot
>>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z"
>>> expression = sqlglot.parse_one(sql)
>>> eliminate_ctes(expression).sql()
'SELECT a FROM z'
Arguments:
  • expression (sqlglot.Expression): expression to optimize

Returns:
  • sqlglot.Expression: optimized expression