sqlglot.optimizer.eliminate_subqueries
from __future__ import annotations

import itertools
import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, build_scope

if t.TYPE_CHECKING:
    # Aliases used only in annotations; never evaluated at runtime.
    ExistingCTEsMapping = dict[exp.Expr, str]
    TakenNameMapping = dict[str, t.Union[Scope, exp.Expr]]


def eliminate_subqueries(expression: exp.Expr) -> exp.Expr:
    """
    Hoist derived tables into CTEs, reusing a CTE when the same query appears twice.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'

    Args:
        expression (sqlglot.Expr): expression to rewrite; it is modified in place.
    Returns:
        sqlglot.Expr: the same `expression` object that was passed in.
    """
    if isinstance(expression, exp.Subquery):
        # A subquery can sit at the root, e.g. (SELECT * FROM x) LIMIT 1.
        # Rewrite the wrapped query and hand back the wrapper untouched.
        eliminate_subqueries(expression.this)
        return expression

    root = build_scope(expression)
    if not root:
        return expression

    # alias -> Scope | Table for every name already in use, so that
    # freshly generated CTE names never collide with them.
    taken: TakenNameMapping = {}

    # Aliases of the CTEs declared on the root query.
    for cte_scope in root.cte_scopes:
        cte_node = cte_scope.expression.parent
        if cte_node:
            taken[cte_node.alias] = cte_scope

    # Names of every physical table referenced anywhere below the root.
    for visited in root.traverse():
        for source in visited.sources.values():
            if isinstance(source, exp.Table):
                taken[source.name] = source

    # Expr -> alias for CTEs that already exist on the root; consulted for deduplication.
    existing_ctes: ExistingCTEsMapping = {}
    recursive = False

    with_ = root.expression.args.get("with_")
    if with_:
        recursive = with_.args.get("recursive")
        existing_ctes.update((cte.this, cte.alias) for cte in with_.expressions)

    ordered_ctes = []

    # Keep the DAG order: anything hoisted out of an existing CTE must be
    # declared before that CTE.
    for cte_scope in root.cte_scopes:
        for inner in cte_scope.traverse():
            # The existing CTE itself is not a candidate for elimination.
            if inner is not cte_scope:
                hoisted = _eliminate(inner, existing_ctes, taken)
                if hoisted:
                    ordered_ctes.append(hoisted)

        # Then the existing CTE.
        cte_node = cte_scope.expression.parent
        if cte_node:
            ordered_ctes.append(cte_node)

    # Everything that is not under a root-level CTE comes last.
    remaining = itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes)
    for top_scope in remaining:
        for inner in top_scope.traverse():
            hoisted = _eliminate(inner, existing_ctes, taken)
            if hoisted:
                ordered_ctes.append(hoisted)

    if not ordered_ctes:
        return expression

    # For DDL the WITH clause is attached to the wrapped query, not the DDL node.
    if isinstance(expression, exp.DDL):
        query = expression.expression
    else:
        query = expression
    query.set("with_", exp.With(expressions=ordered_ctes, recursive=recursive))

    return expression


def _eliminate(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> exp.Expr | None:
    """Dispatch `scope` to the derived-table or CTE handler; other scopes yield None."""
    if scope.is_derived_table:
        handler = _eliminate_derived_table
    elif scope.is_cte:
        handler = _eliminate_cte
    else:
        return None

    return handler(scope, existing_ctes, taken)


def _eliminate_derived_table(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> exp.Expr | None:
    """
    Replace a derived table with a table reference to a CTE.

    Returns:
        The new CTE, or None when nothing was rewritten or an existing CTE was reused.
    """
    outer_scope = scope.parent

    # These guards make sure that we don't:
    # - drop the "pivot" arg from a pivoted subquery
    # - eliminate a lateral correlated subquery
    if not outer_scope:
        return None
    if outer_scope.pivots:
        return None
    if isinstance(outer_scope.expression, exp.Lateral):
        return None

    wrapper = scope.expression.parent
    if not isinstance(wrapper, exp.Subquery):
        return None

    # Skip past exp.Subquery nodes that only act as redundant wrappers.
    to_replace = wrapper.unwrap()
    name, cte = _new_cte(scope, existing_ctes, taken)

    reference = exp.alias_(exp.table_(name), alias=to_replace.alias or name)
    # Carry over any joins that were attached to the subquery being replaced.
    reference.set("joins", to_replace.args.get("joins"))
    to_replace.replace(reference)

    return cte


def _eliminate_cte(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> exp.Expr | None:
    """
    Detach a CTE from its WITH clause and repoint references at its new name.

    Returns:
        The new CTE, or None when the scope has no parent node or an existing CTE was reused.
    """
    cte_node = scope.expression.parent
    if not cte_node:
        return None

    name, cte = _new_cte(scope, existing_ctes, taken)

    # Grab the WITH clause before detaching, then drop it too if it ends up empty.
    with_ = cte_node.parent
    cte_node.pop()
    if with_ and not with_.expressions:
        with_.pop()

    # Rename references to this CTE
    enclosing = scope.parent
    if enclosing:
        for nested in enclosing.traverse():
            for table, source in nested.selected_sources.values():
                if source is scope:
                    renamed = exp.alias_(
                        exp.table_(name), alias=table.alias_or_name, copy=False
                    )
                    table.replace(renamed)

    return cte


def _new_cte(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> tuple[str, exp.Expr | None]:
    """
    Returns:
        tuple of (name, cte)
        where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance.
        If this CTE duplicates an existing CTE, `cte` will be None.
    """
    known_alias = existing_ctes.get(scope.expression)

    owner = scope.expression.parent
    candidate = owner.alias if owner else ""
    if not candidate:
        candidate = find_new_name(taken=taken, base="cte")

    if known_alias:
        # Same query as an already registered CTE: reuse its alias, create nothing.
        taken[known_alias] = scope
        return known_alias, None

    if taken.get(candidate):
        # The preferred name is in use; derive a fresh one from it.
        candidate = find_new_name(taken=taken, base=candidate)

    taken[candidate] = scope
    existing_ctes[scope.expression] = candidate

    new_cte = exp.CTE(
        this=scope.expression,
        alias=exp.TableAlias(this=exp.to_identifier(candidate)),
    )
    return candidate, new_cte
def
eliminate_subqueries( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def eliminate_subqueries(expression: exp.Expr) -> exp.Expr:
    """
    Hoist derived tables into CTEs, reusing a CTE when the same query appears twice.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

    This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'

    Args:
        expression (sqlglot.Expr): expression to rewrite; it is modified in place.
    Returns:
        sqlglot.Expr: the same `expression` object that was passed in.
    """
    if isinstance(expression, exp.Subquery):
        # A subquery can sit at the root, e.g. (SELECT * FROM x) LIMIT 1.
        # Rewrite the wrapped query and hand back the wrapper untouched.
        eliminate_subqueries(expression.this)
        return expression

    root = build_scope(expression)
    if not root:
        return expression

    # alias -> Scope | Table for every name already in use, so that
    # freshly generated CTE names never collide with them.
    taken: TakenNameMapping = {}

    # Aliases of the CTEs declared on the root query.
    for cte_scope in root.cte_scopes:
        cte_node = cte_scope.expression.parent
        if cte_node:
            taken[cte_node.alias] = cte_scope

    # Names of every physical table referenced anywhere below the root.
    for visited in root.traverse():
        for source in visited.sources.values():
            if isinstance(source, exp.Table):
                taken[source.name] = source

    # Expr -> alias for CTEs that already exist on the root; consulted for deduplication.
    existing_ctes: ExistingCTEsMapping = {}
    recursive = False

    with_ = root.expression.args.get("with_")
    if with_:
        recursive = with_.args.get("recursive")
        existing_ctes.update((cte.this, cte.alias) for cte in with_.expressions)

    ordered_ctes = []

    # Keep the DAG order: anything hoisted out of an existing CTE must be
    # declared before that CTE.
    for cte_scope in root.cte_scopes:
        for inner in cte_scope.traverse():
            # The existing CTE itself is not a candidate for elimination.
            if inner is not cte_scope:
                hoisted = _eliminate(inner, existing_ctes, taken)
                if hoisted:
                    ordered_ctes.append(hoisted)

        # Then the existing CTE.
        cte_node = cte_scope.expression.parent
        if cte_node:
            ordered_ctes.append(cte_node)

    # Everything that is not under a root-level CTE comes last.
    remaining = itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes)
    for top_scope in remaining:
        for inner in top_scope.traverse():
            hoisted = _eliminate(inner, existing_ctes, taken)
            if hoisted:
                ordered_ctes.append(hoisted)

    if not ordered_ctes:
        return expression

    # For DDL the WITH clause is attached to the wrapped query, not the DDL node.
    if isinstance(expression, exp.DDL):
        query = expression.expression
    else:
        query = expression
    query.set("with_", exp.With(expressions=ordered_ctes, recursive=recursive))

    return expression
Rewrite derived tables as CTEs, deduplicating if possible.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
Arguments:
- expression (sqlglot.Expr): the expression to rewrite; it is modified in place.
Returns:
sqlglot.Expr: the same expression object that was passed in, with its derived tables rewritten as CTEs.