sqlglot.optimizer.eliminate_subqueries
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import expressions as exp 7from sqlglot.helper import find_new_name 8from sqlglot.optimizer.scope import Scope, build_scope 9 10if t.TYPE_CHECKING: 11 ExistingCTEsMapping = t.Dict[exp.Expression, str] 12 TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]] 13 14 15def eliminate_subqueries(expression: exp.Expression) -> exp.Expression: 16 """ 17 Rewrite derived tables as CTES, deduplicating if possible. 18 19 Example: 20 >>> import sqlglot 21 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y") 22 >>> eliminate_subqueries(expression).sql() 23 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y' 24 25 This also deduplicates common subqueries: 26 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z") 27 >>> eliminate_subqueries(expression).sql() 28 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z' 29 30 Args: 31 expression (sqlglot.Expression): expression 32 Returns: 33 sqlglot.Expression: expression 34 """ 35 if isinstance(expression, exp.Subquery): 36 # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1 37 eliminate_subqueries(expression.this) 38 return expression 39 40 root = build_scope(expression) 41 42 if not root: 43 return expression 44 45 # Map of alias->Scope|Table 46 # These are all aliases that are already used in the expression. 47 # We don't want to create new CTEs that conflict with these names. 48 taken: TakenNameMapping = {} 49 50 # All CTE aliases in the root scope are taken 51 for scope in root.cte_scopes: 52 taken[scope.expression.parent.alias] = scope 53 54 # All table names are taken 55 for scope in root.traverse(): 56 taken.update( 57 { 58 source.name: source 59 for _, source in scope.sources.items() 60 if isinstance(source, exp.Table) 61 } 62 ) 63 64 # Map of Expression->alias 65 # Existing CTES in the root expression. 
We'll use this for deduplication. 66 existing_ctes: ExistingCTEsMapping = {} 67 68 with_ = root.expression.args.get("with") 69 recursive = False 70 if with_: 71 recursive = with_.args.get("recursive") 72 for cte in with_.expressions: 73 existing_ctes[cte.this] = cte.alias 74 new_ctes = [] 75 76 # We're adding more CTEs, but we want to maintain the DAG order. 77 # Derived tables within an existing CTE need to come before the existing CTE. 78 for cte_scope in root.cte_scopes: 79 # Append all the new CTEs from this existing CTE 80 for scope in cte_scope.traverse(): 81 if scope is cte_scope: 82 # Don't try to eliminate this CTE itself 83 continue 84 new_cte = _eliminate(scope, existing_ctes, taken) 85 if new_cte: 86 new_ctes.append(new_cte) 87 88 # Append the existing CTE itself 89 new_ctes.append(cte_scope.expression.parent) 90 91 # Now append the rest 92 for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes): 93 for child_scope in scope.traverse(): 94 new_cte = _eliminate(child_scope, existing_ctes, taken) 95 if new_cte: 96 new_ctes.append(new_cte) 97 98 if new_ctes: 99 query = expression.expression if isinstance(expression, exp.DDL) else expression 100 query.set("with", exp.With(expressions=new_ctes, recursive=recursive)) 101 102 return expression 103 104 105def _eliminate( 106 scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping 107) -> t.Optional[exp.Expression]: 108 if scope.is_derived_table: 109 return _eliminate_derived_table(scope, existing_ctes, taken) 110 111 if scope.is_cte: 112 return _eliminate_cte(scope, existing_ctes, taken) 113 114 return None 115 116 117def _eliminate_derived_table( 118 scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping 119) -> t.Optional[exp.Expression]: 120 # This makes sure that we don't: 121 # - drop the "pivot" arg from a pivoted subquery 122 # - eliminate a lateral correlated subquery 123 if scope.parent.pivots or isinstance(scope.parent.expression, 
exp.Lateral): 124 return None 125 126 # Get rid of redundant exp.Subquery expressions, i.e. those that are just used as wrappers 127 to_replace = scope.expression.parent.unwrap() 128 name, cte = _new_cte(scope, existing_ctes, taken) 129 table = exp.alias_(exp.table_(name), alias=to_replace.alias or name) 130 table.set("joins", to_replace.args.get("joins")) 131 132 to_replace.replace(table) 133 134 return cte 135 136 137def _eliminate_cte( 138 scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping 139) -> t.Optional[exp.Expression]: 140 parent = scope.expression.parent 141 name, cte = _new_cte(scope, existing_ctes, taken) 142 143 with_ = parent.parent 144 parent.pop() 145 if not with_.expressions: 146 with_.pop() 147 148 # Rename references to this CTE 149 for child_scope in scope.parent.traverse(): 150 for table, source in child_scope.selected_sources.values(): 151 if source is scope: 152 new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False) 153 table.replace(new_table) 154 155 return cte 156 157 158def _new_cte( 159 scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping 160) -> t.Tuple[str, t.Optional[exp.Expression]]: 161 """ 162 Returns: 163 tuple of (name, cte) 164 where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance. 165 If this CTE duplicates an existing CTE, `cte` will be None. 
166 """ 167 duplicate_cte_alias = existing_ctes.get(scope.expression) 168 parent = scope.expression.parent 169 name = parent.alias 170 171 if not name: 172 name = find_new_name(taken=taken, base="cte") 173 174 if duplicate_cte_alias: 175 name = duplicate_cte_alias 176 elif taken.get(name): 177 name = find_new_name(taken=taken, base=name) 178 179 taken[name] = scope 180 181 if not duplicate_cte_alias: 182 existing_ctes[scope.expression] = name 183 cte = exp.CTE( 184 this=scope.expression, 185 alias=exp.TableAlias(this=exp.to_identifier(name)), 186 ) 187 else: 188 cte = None 189 return name, cte
def eliminate_subqueries(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
    """
    Hoist derived tables into CTEs, reusing an existing CTE when the query is a duplicate.

    The expression is modified in place and the same object is returned.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'

        This also deduplicates common subqueries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
        >>> eliminate_subqueries(expression).sql()
        'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'

    Args:
        expression (sqlglot.Expression): expression to rewrite
    Returns:
        sqlglot.Expression: the same expression, after rewriting
    """
    if isinstance(expression, exp.Subquery):
        # A subquery can sit at the root, e.g. (SELECT * FROM x) LIMIT 1;
        # rewrite the wrapped query and hand back the wrapper untouched.
        eliminate_subqueries(expression.this)
        return expression

    root = build_scope(expression)
    if not root:
        return expression

    # alias -> Scope | Table for every name already in use.
    # Newly created CTEs must not collide with any of these.
    taken: TakenNameMapping = {}

    # Aliases of the CTEs that already exist in the root scope
    for existing_scope in root.cte_scopes:
        taken[existing_scope.expression.parent.alias] = existing_scope

    # Names of every physical table referenced anywhere in the tree
    for visited in root.traverse():
        for source in visited.sources.values():
            if isinstance(source, exp.Table):
                taken[source.name] = source

    # Expression -> alias for the CTEs already present on the root expression.
    # Consulted later to deduplicate identical queries.
    existing_ctes: ExistingCTEsMapping = {}

    recursive = False
    with_ = root.expression.args.get("with")
    if with_:
        recursive = with_.args.get("recursive")
        existing_ctes.update({cte.this: cte.alias for cte in with_.expressions})

    new_ctes = []

    # New CTEs are added while keeping the DAG order intact: anything hoisted
    # out of an existing CTE has to be declared before that CTE.
    for cte_scope in root.cte_scopes:
        for inner_scope in cte_scope.traverse():
            # The existing CTE itself is not a candidate for elimination
            if inner_scope is not cte_scope:
                hoisted = _eliminate(inner_scope, existing_ctes, taken)
                if hoisted:
                    new_ctes.append(hoisted)

        # The existing CTE follows everything that was hoisted out of it
        new_ctes.append(cte_scope.expression.parent)

    # Everything outside the existing CTEs comes last
    remaining = itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes)
    for top_scope in remaining:
        for inner_scope in top_scope.traverse():
            hoisted = _eliminate(inner_scope, existing_ctes, taken)
            if hoisted:
                new_ctes.append(hoisted)

    if not new_ctes:
        return expression

    # For DDL statements the WITH clause belongs on the wrapped query
    target = expression.expression if isinstance(expression, exp.DDL) else expression
    target.set("with", exp.With(expressions=new_ctes, recursive=recursive))

    return expression
Rewrite derived tables as CTEs, deduplicating if possible.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
Arguments:
- expression (sqlglot.Expression): the expression whose derived tables are rewritten as CTEs; it is modified in place
Returns:
sqlglot.Expression: the same expression object, after rewriting