Edit on GitHub

sqlglot.optimizer.eliminate_subqueries

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.helper import find_new_name
  8from sqlglot.optimizer.scope import Scope, build_scope
  9
 10if t.TYPE_CHECKING:  # aliases below are only needed by type checkers (annotations are lazy strings)
 11    ExistingCTEsMapping = dict[exp.Expr, str]  # CTE body expression -> alias it is registered under
 12    TakenNameMapping = dict[str, t.Union[Scope, exp.Expr]]  # name already in use -> Scope or Table using it
 13
 14
 15def eliminate_subqueries(expression: exp.Expr) -> exp.Expr:
 16    """
 17    Rewrite derived tables as CTES, deduplicating if possible.
 18
 19    Example:
 20        >>> import sqlglot
 21        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
 22        >>> eliminate_subqueries(expression).sql()
 23        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
 24
 25    This also deduplicates common subqueries:
 26        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
 27        >>> eliminate_subqueries(expression).sql()
 28        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
 29
 30    Args:
 31        expression (sqlglot.Expr): expression
 32    Returns:
 33        sqlglot.Expr: expression
 34    """
 35    if isinstance(expression, exp.Subquery):
 36        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
 37        eliminate_subqueries(expression.this)
 38        return expression
 39
 40    root = build_scope(expression)
 41
 42    if not root:
 43        return expression
 44
 45    # Map of alias->Scope|Table
 46    # These are all aliases that are already used in the expression.
 47    # We don't want to create new CTEs that conflict with these names.
 48    taken: TakenNameMapping = {}
 49
 50    # All CTE aliases in the root scope are taken
 51    for scope in root.cte_scopes:
 52        parent = scope.expression.parent
 53        if parent:
 54            taken[parent.alias] = scope
 55
 56    # All table names are taken
 57    for scope in root.traverse():
 58        taken.update(
 59            {
 60                source.name: source
 61                for _, source in scope.sources.items()
 62                if isinstance(source, exp.Table)
 63            }
 64        )
 65
 66    # Map of Expr->alias
 67    # Existing CTES in the root expression. We'll use this for deduplication.
 68    existing_ctes: ExistingCTEsMapping = {}
 69
 70    with_ = root.expression.args.get("with_")
 71    recursive = False
 72    if with_:
 73        recursive = with_.args.get("recursive")
 74        for cte in with_.expressions:
 75            existing_ctes[cte.this] = cte.alias
 76    new_ctes = []
 77
 78    # We're adding more CTEs, but we want to maintain the DAG order.
 79    # Derived tables within an existing CTE need to come before the existing CTE.
 80    for cte_scope in root.cte_scopes:
 81        # Append all the new CTEs from this existing CTE
 82        for scope in cte_scope.traverse():
 83            if scope is cte_scope:
 84                # Don't try to eliminate this CTE itself
 85                continue
 86            new_cte = _eliminate(scope, existing_ctes, taken)
 87            if new_cte:
 88                new_ctes.append(new_cte)
 89
 90        # Append the existing CTE itself
 91        cte_parent = cte_scope.expression.parent
 92        if cte_parent:
 93            new_ctes.append(cte_parent)
 94
 95    # Now append the rest
 96    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
 97        for child_scope in scope.traverse():
 98            new_cte = _eliminate(child_scope, existing_ctes, taken)
 99            if new_cte:
100                new_ctes.append(new_cte)
101
102    if new_ctes:
103        query = expression.expression if isinstance(expression, exp.DDL) else expression
104        query.set("with_", exp.With(expressions=new_ctes, recursive=recursive))
105
106    return expression
107
108
def _eliminate(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> exp.Expr | None:
    """
    Hoist ``scope`` into a root-level CTE if it is a derived table or a CTE.

    Derived tables are checked first, then CTEs; any other kind of scope is
    left untouched.

    Returns:
        The new CTE node to add to the root WITH clause, or None when there is
        nothing new to add.
    """
    if scope.is_derived_table:
        handler = _eliminate_derived_table
    elif scope.is_cte:
        handler = _eliminate_cte
    else:
        return None
    return handler(scope, existing_ctes, taken)
119
120
def _eliminate_derived_table(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> exp.Expr | None:
    """
    Replace the derived table for ``scope`` with a reference to a root-level CTE.

    The subquery node is swapped for a table reference named after the CTE. The
    reference is aliased with the subquery's own alias (or the CTE name if it has
    none) and keeps the subquery's ``joins`` arg.

    Returns:
        The new CTE, or None if the scope was skipped or duplicates an existing CTE.
    """
    outer = scope.parent
    # Bail out so that we neither drop the "pivot" arg from a pivoted subquery
    # nor eliminate a lateral (correlated) subquery.
    if not outer or outer.pivots or isinstance(outer.expression, exp.Lateral):
        return None

    wrapper = scope.expression.parent
    if not isinstance(wrapper, exp.Subquery):
        return None

    # Strip redundant Subquery nodes that only act as wrappers.
    target = wrapper.unwrap()
    name, cte = _new_cte(scope, existing_ctes, taken)

    reference = exp.alias_(exp.table_(name), alias=target.alias or name)
    reference.set("joins", target.args.get("joins"))
    target.replace(reference)

    return cte
144
145
def _eliminate_cte(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> exp.Expr | None:
    """
    Hoist a nested CTE into the root WITH clause.

    The CTE node is detached from its WITH clause, and that clause is removed once
    it is empty. Every table in the parent scope (and its descendant scopes) that
    selects from this scope is then replaced with a reference to the hoisted name,
    keeping its alias.

    Returns:
        The new CTE, or None if the CTE node has no parent or duplicates an
        existing CTE.
    """
    cte_node = scope.expression.parent
    if not cte_node:
        return None

    name, cte = _new_cte(scope, existing_ctes, taken)

    # Detach the CTE; drop the enclosing WITH clause if it is now empty.
    with_clause = cte_node.parent
    cte_node.pop()
    if with_clause and not with_clause.expressions:
        with_clause.pop()

    owner_scope = scope.parent
    if owner_scope:
        # Rename references to this CTE
        for child_scope in owner_scope.traverse():
            for table, source in child_scope.selected_sources.values():
                if source is not scope:
                    continue
                table.replace(
                    exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False)
                )

    return cte
169
170
def _new_cte(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> tuple[str, exp.Expr | None]:
    """
    Choose a root-level CTE name for ``scope`` and build the CTE node for it.

    If ``scope.expression`` is already registered in ``existing_ctes``, its alias
    is reused and no CTE is built. Otherwise the name is the alias of the
    expression's parent, or "cte" when there is none. If that name is already
    taken, ``find_new_name`` derives a fresh one, and the expression is then
    registered in ``existing_ctes``. The chosen name is always recorded in ``taken``.

    Returns:
        tuple of (name, cte)
        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
        If this CTE duplicates an existing CTE, `cte` will be None.
    """
    expression = scope.expression
    reused_alias = existing_ctes.get(expression)

    if reused_alias:
        name = reused_alias
    else:
        owner = expression.parent
        name = (owner.alias if owner else "") or find_new_name(taken=taken, base="cte")
        if taken.get(name):
            name = find_new_name(taken=taken, base=name)

    taken[name] = scope

    if reused_alias:
        return name, None

    existing_ctes[expression] = name
    return name, exp.CTE(
        this=expression,
        alias=exp.TableAlias(this=exp.to_identifier(name)),
    )
def eliminate_subqueries(expression: exp.Expr) -> exp.Expr:
    """
    Rewrite derived tables as CTEs, deduplicating if possible.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'

    Args:
        expression (sqlglot.Expr): the expression to rewrite. It is modified in
            place: derived tables are replaced by references to CTEs collected
            into a single WITH clause on the root query.
    Returns:
        sqlglot.Expr: the same ``expression`` object. For a DDL statement the
        WITH clause is attached to its wrapped query (``expression.expression``).
    """
    if isinstance(expression, exp.Subquery):
        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
        eliminate_subqueries(expression.this)
        return expression

    root = build_scope(expression)

    if not root:
        return expression

    # Map of alias->Scope|Table
    # These are all aliases that are already used in the expression.
    # We don't want to create new CTEs that conflict with these names.
    taken: TakenNameMapping = {}

    # All CTE aliases in the root scope are taken
    for scope in root.cte_scopes:
        parent = scope.expression.parent
        if parent:
            taken[parent.alias] = scope

    # All referenced table names are taken
    for scope in root.traverse():
        taken.update(
            {
                source.name: source
                for source in scope.sources.values()
                if isinstance(source, exp.Table)
            }
        )

    # Map of Expr->alias
    # Existing CTEs in the root expression. We'll use this for deduplication.
    existing_ctes: ExistingCTEsMapping = {}

    with_ = root.expression.args.get("with_")
    recursive = False
    if with_:
        recursive = with_.args.get("recursive")
        for cte in with_.expressions:
            existing_ctes[cte.this] = cte.alias
    new_ctes = []

    # We're adding more CTEs, but we want to maintain the DAG order.
    # Derived tables within an existing CTE need to come before the existing CTE.
    for cte_scope in root.cte_scopes:
        # Append all the new CTEs from this existing CTE
        for scope in cte_scope.traverse():
            if scope is cte_scope:
                # Don't try to eliminate this CTE itself
                continue
            new_cte = _eliminate(scope, existing_ctes, taken)
            if new_cte:
                new_ctes.append(new_cte)

        # Append the existing CTE itself
        cte_parent = cte_scope.expression.parent
        if cte_parent:
            new_ctes.append(cte_parent)

    # Now append the rest
    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
        for child_scope in scope.traverse():
            new_cte = _eliminate(child_scope, existing_ctes, taken)
            if new_cte:
                new_ctes.append(new_cte)

    if new_ctes:
        # DDL statements carry the query in `.expression`; the WITH belongs there.
        query = expression.expression if isinstance(expression, exp.DDL) else expression
        # This replaces any existing WITH clause: the root's existing CTEs were
        # re-appended to new_ctes above, in dependency order.
        query.set("with_", exp.With(expressions=new_ctes, recursive=recursive))

    return expression

Rewrite derived tables as CTEs, deduplicating identical subqueries where possible.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

It also deduplicates identical subqueries, so that they share a single CTE:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
Arguments:
  • expression (sqlglot.Expr): the expression to rewrite; it is modified in place.
Returns:

sqlglot.Expr: the same expression, with its derived tables rewritten as CTEs.