sqlglot.optimizer.unnest_subqueries
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, find_in_scope, traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert scalar subqueries into cross joins.
    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.

    The expression is modified in place and is also returned.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # every generated join / column alias is drawn from this "_u_" sequence
    next_alias_name = name_sequence("_u_")

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        if not parent:
            # not nested inside another SELECT, nothing to unnest
            continue
        if scope.external_columns:
            # the scope references columns of an outer query -> correlated subquery
            decorrelate(select, parent, scope.external_columns, next_alias_name)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(select, parent, next_alias_name)

    return expression


def unnest(select, parent_select, next_alias_name):
    """
    Rewrite an uncorrelated, single-column subquery into a join on `parent_select`.

    A subquery used outside of IN / ANY is treated as a scalar and is cross joined.
    A subquery used with IN / ANY is grouped by its value and left joined, and the
    original predicate is replaced by a NOT NULL check on the join key.

    The tree is mutated in place; shapes that are not handled are left untouched.

    Args:
        select: the subquery to unnest (a SELECT or a set operation).
        parent_select: the query that contains `select`.
        next_alias_name: callable that returns a fresh alias on every call.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    if isinstance(select, exp.SetOperation):
        # wrap the set operation in a SELECT so that it can be joined / grouped below;
        # the wrapper is freshly built, so it is not attached to the original tree
        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))

    alias = next_alias_name()
    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause_parent_select = clause.parent_select if clause else None

        # in a HAVING clause, or in a grouped / aggregated parent outside of its own
        # WHERE / JOIN clauses, the joined column is wrapped in MAX
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(
                    find_in_scope(projection, exp.AggFunc)
                    for projection in parent_select.selects
                )
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        if select.parent is None:
            # a wrapped set operation has no node in the tree that could be replaced;
            # skip the rewrite instead of failing on `None.replace`
            return

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # only `= ANY (...)` is handled; the enclosing equality is what gets replaced
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    join_key = exp.column(value.alias, alias)
    join_key_not_null = join_key.is_(exp.null()).not_()

    if isinstance(clause, exp.Join):
        # keep the JOIN condition valid and move the membership check to WHERE
        _replace(predicate, exp.true())
        parent_select.where(join_key_not_null, copy=False)
    else:
        _replace(predicate, join_key_not_null)

    group = select.args.get("group")

    if group:
        if {value.this} != set(group.expressions):
            # the subquery is grouped by something other than its value:
            # wrap it and group the wrapper by the value instead
            select = (
                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
                .from_(select.subquery("_q", copy=False), copy=False)
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    else:
        select = select.group_by(value.this, copy=False)

    parent_select.join(
        select,
        on=column.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery into a grouped subquery that is left joined to `parent_select`.

    Every external column must be used in a binary predicate of the subquery's WHERE clause.
    Those predicates are replaced by TRUE inside the subquery; equalities become the join
    condition, other predicates are re-applied in the outer query.

    The tree is mutated in place. The subquery is left untouched when it has no WHERE
    clause, when the WHERE clause contains OR, when the subquery contains LIMIT / OFFSET,
    when an external column is used outside of the WHERE clause or in a non-binary
    predicate, or when none of the correlating predicates is an equality.

    Args:
        select: the correlated subquery.
        parent_select: the query that contains `select`.
        external_columns: columns used in `select` that belong to an outer query.
        next_alias_name: callable that returns a fresh alias on every call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is the operand on the opposite side of the external column
            key = (
                predicate.right
                if any(node is column for node in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery itself is one of the parent's projections
    is_subquery_projection = any(
        node is select.parent
        for node in map(lambda s: s.unalias(), parent_select.selects)
        if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # keys that are not grouped are aggregated so the subquery stays grouped
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # reference to the subquery's value through the join alias
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)
    op_type = type(parent_predicate.parent) if parent_predicate else None

    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        assert issubclass(op_type, exp.Binary)
        predicate = op_type(this=other, expression=exp.column("_x"))
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
        )
    elif isinstance(parent_predicate, exp.Any):
        assert issubclass(op_type, exp.Binary)
        if value.this in group_by:
            predicate = op_type(this=other, expression=alias)
            parent_predicate = _replace(parent_predicate.parent, predicate)
        else:
            predicate = op_type(this=other, expression=exp.column("_x"))
            # the lambda already contains the other operand and the operator, so the whole
            # comparison (the parent of ANY) is replaced, as in the ALL branch above
            parent_predicate = _replace(
                parent_predicate.parent, f"ARRAY_ANY({alias}, _x -> {predicate})"
            )
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection and select.parent.alias:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)])

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # the correlation no longer lives inside the subquery
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            if not isinstance(predicate, exp.EQ):
                parent_select.where(predicate, copy=False)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # the (now detached) equality predicates become the join condition
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _replace(expression, condition):
    """Replace `expression` in its tree by `exp.condition(condition)` and return the result of `replace`."""
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """
    Return the operand that a subquery is compared against.

    For IN this is the left-hand side. For ANY / ALL the enclosing expression is
    inspected. For a binary expression it is the side that is not a subquery-like
    node. None is returned for anything else.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
def
unnest_subqueries(expression):
def unnest_subqueries(expression):
    """
    Turn subqueries that are used inside predicates into joins.

    Scalar subqueries become cross joins. Correlated or vectorized subqueries
    become grouped subqueries, so that the resulting left join is not many to many.
    The given expression is rewritten in place and returned.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    fresh_alias = name_sequence("_u_")

    for scope in traverse_scope(expression):
        query = scope.expression
        outer_query = query.parent_select

        # only queries nested inside another SELECT are candidates
        if outer_query:
            correlated_columns = scope.external_columns

            if correlated_columns:
                decorrelate(query, outer_query, correlated_columns, fresh_alias)
            elif scope.scope_type == ScopeType.SUBQUERY:
                unnest(query, outer_query, fresh_alias)

    return expression
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Scalar subqueries are converted into cross joins. Correlated or vectorized subqueries are converted into a grouped (GROUP BY) subquery that is left joined, so that the join is not many-to-many.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
- expression (sqlglot.Expression): expression to unnest
Returns:
sqlglot.Expression: unnested expression
def
unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Rewrite an uncorrelated, single-column subquery into a join on `parent_select`.

    A subquery used outside of IN / ANY is treated as a scalar and is cross joined.
    A subquery used with IN / ANY is grouped by its value and left joined, and the
    original predicate is replaced by a NOT NULL check on the join key.

    The tree is mutated in place; shapes that are not handled are left untouched.

    Args:
        select: the subquery to unnest (a SELECT or a set operation).
        parent_select: the query that contains `select`.
        next_alias_name: callable that returns a fresh alias on every call.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    if isinstance(select, exp.SetOperation):
        # wrap the set operation in a SELECT so that it can be joined / grouped below;
        # the wrapper is freshly built, so it is not attached to the original tree
        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))

    alias = next_alias_name()
    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause_parent_select = clause.parent_select if clause else None

        # in a HAVING clause, or in a grouped / aggregated parent outside of its own
        # WHERE / JOIN clauses, the joined column is wrapped in MAX
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(
                    find_in_scope(projection, exp.AggFunc)
                    for projection in parent_select.selects
                )
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        if select.parent is None:
            # a wrapped set operation has no node in the tree that could be replaced;
            # skip the rewrite instead of failing on `None.replace`
            return

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # only `= ANY (...)` is handled; the enclosing equality is what gets replaced
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    join_key = exp.column(value.alias, alias)
    join_key_not_null = join_key.is_(exp.null()).not_()

    if isinstance(clause, exp.Join):
        # keep the JOIN condition valid and move the membership check to WHERE
        _replace(predicate, exp.true())
        parent_select.where(join_key_not_null, copy=False)
    else:
        _replace(predicate, join_key_not_null)

    group = select.args.get("group")

    if group:
        if {value.this} != set(group.expressions):
            # the subquery is grouped by something other than its value:
            # wrap it and group the wrapper by the value instead
            select = (
                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
                .from_(select.subquery("_q", copy=False), copy=False)
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    else:
        select = select.group_by(value.this, copy=False)

    parent_select.join(
        select,
        on=column.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def
decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery into a grouped subquery that is left joined to `parent_select`.

    Every external column must be used in a binary predicate of the subquery's WHERE clause.
    Those predicates are replaced by TRUE inside the subquery; equalities become the join
    condition, other predicates are re-applied in the outer query.

    The tree is mutated in place. The subquery is left untouched when it has no WHERE
    clause, when the WHERE clause contains OR, when the subquery contains LIMIT / OFFSET,
    when an external column is used outside of the WHERE clause or in a non-binary
    predicate, or when none of the correlating predicates is an equality.

    Args:
        select: the correlated subquery.
        parent_select: the query that contains `select`.
        external_columns: columns used in `select` that belong to an outer query.
        next_alias_name: callable that returns a fresh alias on every call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is the operand on the opposite side of the external column
            key = (
                predicate.right
                if any(node is column for node in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery itself is one of the parent's projections
    is_subquery_projection = any(
        node is select.parent
        for node in map(lambda s: s.unalias(), parent_select.selects)
        if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # keys that are not grouped are aggregated so the subquery stays grouped
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # reference to the subquery's value through the join alias
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)
    op_type = type(parent_predicate.parent) if parent_predicate else None

    if isinstance(parent_predicate, exp.Exists):
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        assert issubclass(op_type, exp.Binary)
        predicate = op_type(this=other, expression=exp.column("_x"))
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
        )
    elif isinstance(parent_predicate, exp.Any):
        assert issubclass(op_type, exp.Binary)
        if value.this in group_by:
            predicate = op_type(this=other, expression=alias)
            parent_predicate = _replace(parent_predicate.parent, predicate)
        else:
            predicate = op_type(this=other, expression=exp.column("_x"))
            # the lambda already contains the other operand and the operator, so the whole
            # comparison (the parent of ANY) is replaced, as in the ALL branch above
            parent_predicate = _replace(
                parent_predicate.parent, f"ARRAY_ANY({alias}, _x -> {predicate})"
            )
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection and select.parent.alias:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)])

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # the correlation no longer lives inside the subquery
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            if not isinstance(predicate, exp.EQ):
                parent_select.where(predicate, copy=False)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # the (now detached) equality predicates become the join condition
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )