Edit on GitHub

sqlglot.optimizer.eliminate_subqueries

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.helper import find_new_name
  8from sqlglot.optimizer.scope import Scope, build_scope
  9
 10if t.TYPE_CHECKING:
 11    ExistingCTEsMapping = t.Dict[exp.Expression, str]
 12    TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]]
 13
 14
 15def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
 16    """
 17    Rewrite derived tables as CTES, deduplicating if possible.
 18
 19    Example:
 20        >>> import sqlglot
 21        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
 22        >>> eliminate_subqueries(expression).sql()
 23        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
 24
 25    This also deduplicates common subqueries:
 26        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
 27        >>> eliminate_subqueries(expression).sql()
 28        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
 29
 30    Args:
 31        expression (sqlglot.Expression): expression
 32    Returns:
 33        sqlglot.Expression: expression
 34    """
 35    if isinstance(expression, exp.Subquery):
 36        # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
 37        eliminate_subqueries(expression.this)
 38        return expression
 39
 40    root = build_scope(expression)
 41
 42    if not root:
 43        return expression
 44
 45    # Map of alias->Scope|Table
 46    # These are all aliases that are already used in the expression.
 47    # We don't want to create new CTEs that conflict with these names.
 48    taken: TakenNameMapping = {}
 49
 50    # All CTE aliases in the root scope are taken
 51    for scope in root.cte_scopes:
 52        taken[scope.expression.parent.alias] = scope
 53
 54    # All table names are taken
 55    for scope in root.traverse():
 56        taken.update(
 57            {
 58                source.name: source
 59                for _, source in scope.sources.items()
 60                if isinstance(source, exp.Table)
 61            }
 62        )
 63
 64    # Map of Expression->alias
 65    # Existing CTES in the root expression. We'll use this for deduplication.
 66    existing_ctes: ExistingCTEsMapping = {}
 67
 68    with_ = root.expression.args.get("with")
 69    recursive = False
 70    if with_:
 71        recursive = with_.args.get("recursive")
 72        for cte in with_.expressions:
 73            existing_ctes[cte.this] = cte.alias
 74    new_ctes = []
 75
 76    # We're adding more CTEs, but we want to maintain the DAG order.
 77    # Derived tables within an existing CTE need to come before the existing CTE.
 78    for cte_scope in root.cte_scopes:
 79        # Append all the new CTEs from this existing CTE
 80        for scope in cte_scope.traverse():
 81            if scope is cte_scope:
 82                # Don't try to eliminate this CTE itself
 83                continue
 84            new_cte = _eliminate(scope, existing_ctes, taken)
 85            if new_cte:
 86                new_ctes.append(new_cte)
 87
 88        # Append the existing CTE itself
 89        new_ctes.append(cte_scope.expression.parent)
 90
 91    # Now append the rest
 92    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
 93        for child_scope in scope.traverse():
 94            new_cte = _eliminate(child_scope, existing_ctes, taken)
 95            if new_cte:
 96                new_ctes.append(new_cte)
 97
 98    if new_ctes:
 99        query = expression.expression if isinstance(expression, exp.DDL) else expression
100        query.set("with", exp.With(expressions=new_ctes, recursive=recursive))
101
102    return expression
103
104
def _eliminate(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """
    Dispatch ``scope`` to the matching hoisting routine.

    Derived tables go to ``_eliminate_derived_table`` and CTE scopes go to
    ``_eliminate_cte``. Any other scope is left alone.

    Returns:
        The CTE node the handler produced, or None when nothing new needs to be
        added to the root WITH clause.
    """
    if scope.is_derived_table:
        handler = _eliminate_derived_table
    elif scope.is_cte:
        handler = _eliminate_cte
    else:
        return None
    return handler(scope, existing_ctes, taken)


def _eliminate_derived_table(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """
    Replace a derived table with an aliased reference to its CTE.

    Returns:
        The new CTE to register on the root query. Returns None when the derived
        table is left untouched (the parent scope has pivots, or the parent
        expression is a LATERAL), or when it reuses an existing CTE.
    """
    outer = scope.parent
    # Skip rewrites that would be unsafe:
    # - a pivoted subquery would lose its "pivot" arg
    # - a lateral subquery may be correlated with the outer query
    if outer.pivots or isinstance(outer.expression, exp.Lateral):
        return None

    # unwrap() gets rid of redundant exp.Subquery nodes that only act as wrappers
    replaced = scope.expression.parent.unwrap()
    cte_name, cte = _new_cte(scope, existing_ctes, taken)

    # Keep the subquery's own alias when it has one; otherwise alias by the CTE name.
    reference = exp.alias_(exp.table_(cte_name), alias=replaced.alias or cte_name)
    # Carry over any joins attached to the node being replaced.
    reference.set("joins", replaced.args.get("joins"))
    replaced.replace(reference)

    return cte


def _eliminate_cte(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """
    Detach a nested CTE from its WITH clause so the caller can hoist it to the root.

    Tables in the parent scope tree that resolve to this CTE are rewritten to the
    name chosen by ``_new_cte``.

    Returns:
        The new CTE node, or None if an identical CTE is already registered.
    """
    cte_node = scope.expression.parent
    cte_name, cte = _new_cte(scope, existing_ctes, taken)

    # Remove the CTE from its WITH clause; drop the clause once it holds no CTEs.
    with_clause = cte_node.parent
    cte_node.pop()
    if not with_clause.expressions:
        with_clause.pop()

    # Repoint every table that selected from this CTE at the (possibly renamed) CTE,
    # keeping the table's alias (or its name) as the alias.
    for child_scope in scope.parent.traverse():
        for table, source in child_scope.selected_sources.values():
            if source is not scope:
                continue
            table.replace(
                exp.alias_(exp.table_(cte_name), alias=table.alias_or_name, copy=False)
            )

    return cte


def _new_cte(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Tuple[str, t.Optional[exp.Expression]]:
    """
    Choose the root-level name for ``scope`` and build a CTE for it when needed.

    Naming works as follows:
    - The scope's own alias is preferred.
    - An unaliased scope gets a name generated from the base "cte".
    - An alias that is already taken gets a name generated from that alias.
    Both generated cases use ``find_new_name``. If ``scope.expression`` is already a
    key of ``existing_ctes``, that CTE's alias is reused instead.

    The chosen name is recorded in ``taken``. For a new CTE, the expression is also
    recorded in ``existing_ctes``, so later identical subqueries are deduplicated
    against it.

    Returns:
        tuple of (name, cte)
        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
        If this CTE duplicates an existing CTE, `cte` will be None.
    """
    query = scope.expression
    reused_alias = existing_ctes.get(query)

    name = query.parent.alias or find_new_name(taken=taken, base="cte")
    if reused_alias:
        name = reused_alias
    elif taken.get(name):
        name = find_new_name(taken=taken, base=name)

    taken[name] = scope

    if reused_alias:
        return name, None

    existing_ctes[query] = name
    table_alias = exp.TableAlias(this=exp.to_identifier(name))
    return name, exp.CTE(this=query, alias=table_alias)
def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite derived tables as CTEs, deduplicating if possible.

    Derived tables, and CTEs declared inside nested scopes, are hoisted into a single
    WITH clause on the root query. A subquery whose expression matches an already
    registered CTE body (an existing root CTE or one hoisted earlier) reuses that
    CTE's name. New names are picked with ``find_new_name`` so that they avoid the
    root CTE aliases and the table names referenced anywhere in the query.
    The tree is modified in place.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'

    Args:
        expression (sqlglot.Expression): the query to rewrite. For DDL, the resulting
            WITH clause is attached to ``expression.expression``.
    Returns:
        sqlglot.Expression: the same ``expression`` object, after rewriting.
    """
    if isinstance(expression, exp.Subquery):
        # The root itself may be parenthesized, e.g. (SELECT * FROM x) LIMIT 1:
        # rewrite the wrapped query and hand back the wrapper.
        eliminate_subqueries(expression.this)
        return expression

    root = build_scope(expression)
    if not root:
        return expression

    # alias -> owner (root CTE Scope or referenced Table). New CTE names must not
    # collide with any of these. Table names are recorded after the CTE aliases.
    taken: TakenNameMapping = {
        cte_scope.expression.parent.alias: cte_scope for cte_scope in root.cte_scopes
    }
    for scope in root.traverse():
        for source in scope.sources.values():
            if isinstance(source, exp.Table):
                taken[source.name] = source

    # CTE body expression -> alias, for every CTE already on the root query.
    # Identical subqueries found later are pointed at these instead of duplicated.
    existing_ctes: ExistingCTEsMapping = {}
    recursive = False
    root_with = root.expression.args.get("with")
    if root_with:
        recursive = root_with.args.get("recursive")
        for cte in root_with.expressions:
            existing_ctes[cte.this] = cte.alias

    new_ctes = []

    # Keep the CTE list in dependency order: anything hoisted out of an existing
    # root CTE is emitted before that CTE, which is then re-emitted itself.
    for cte_scope in root.cte_scopes:
        for inner_scope in cte_scope.traverse():
            if inner_scope is not cte_scope:
                hoisted = _eliminate(inner_scope, existing_ctes, taken)
                if hoisted:
                    new_ctes.append(hoisted)
        new_ctes.append(cte_scope.expression.parent)

    # Then everything reachable from the root's union, subquery and table scopes.
    for top_scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes):
        for inner_scope in top_scope.traverse():
            hoisted = _eliminate(inner_scope, existing_ctes, taken)
            if hoisted:
                new_ctes.append(hoisted)

    if new_ctes:
        target = expression.expression if isinstance(expression, exp.DDL) else expression
        target.set("with", exp.With(expressions=new_ctes, recursive=recursive))

    return expression

Rewrite derived tables as CTEs, deduplicating if possible.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
Arguments:
  • expression (sqlglot.Expression): expression
Returns:
  • sqlglot.Expression: expression