sqlglot.optimizer.unnest_subqueries
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, find_in_scope, traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert scalar subqueries into cross joins.
    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.

    Every scope yielded by `traverse_scope` that has a parent select is dispatched to
    `decorrelate` (when the scope has external columns) or to `unnest` (when it is a
    plain subquery scope). The tree is rewritten in place and the same object is returned.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # Shared alias generator, so every join added during this call gets a distinct
    # "_u_"-prefixed alias (e.g. `_u_0` in the example above).
    next_alias_name = name_sequence("_u_")

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        # Scopes without an enclosing select have nothing to be joined into.
        if not parent:
            continue
        if scope.external_columns:
            # The subquery references columns of an outer query: correlated case.
            decorrelate(select, parent, scope.external_columns, next_alias_name)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(select, parent, next_alias_name)

    return expression


def unnest(select, parent_select, next_alias_name):
    """
    Turn an uncorrelated, single-projection subquery into a join on `parent_select`.

    Two shapes are handled:
      * the enclosing condition is not IN/ANY: the subquery is treated as a scalar and
        joined with a CROSS JOIN (LEFT JOIN ... ON TRUE for EXISTS);
      * the enclosing condition is IN/ANY: the subquery is grouped by its value and
        LEFT JOINed, and the predicate becomes a NOT NULL check on the join key.

    Args:
        select: the subquery expression to unnest.
        parent_select: the select that encloses the subquery and receives the join.
        next_alias_name: callable returning a fresh alias string on each call.
    Returns:
        None. The function returns early when the subquery shape is not supported.
    """
    # Only subqueries with at most one projection are handled.
    if len(select.selects) > 1:
        return

    # Nearest enclosing condition node; it decides which rewrite applies.
    predicate = select.find_ancestor(exp.Condition)
    if (
        not predicate
        # Do not unnest subqueries inside table-valued functions such as
        # FROM GENERATE_SERIES(...), FROM UNNEST(...) etc in order to preserve join order
        or (
            isinstance(predicate, exp.Func)
            and isinstance(predicate.parent, (exp.Table, exp.From, exp.Join))
        )
        # The condition must live directly in the parent select's own scope ...
        or parent_select is not predicate.parent_select
        # ... and the parent needs a FROM clause to attach a join to.
        or not parent_select.args.get("from_")
    ):
        return

    # Wrap set operations in a plain SELECT over an aliased subquery so that the code
    # below can treat `select` uniformly.
    if isinstance(select, exp.SetOperation):
        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))

    alias = next_alias_name()
    # Clause (HAVING / WHERE / JOIN) the condition sits in, if any.
    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        # Reference to the subquery's single projection through the join alias.
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause_parent_select = clause.parent_select if clause else None

        # Wrap the reference in MAX when it is used in the parent's HAVING, or when it is
        # used outside the parent's own clauses while the parent groups or aggregates.
        # Note: `select` inside the generator below is the generator's own loop variable
        # (a projection of parent_select); it does not rebind the `select` argument.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(find_in_scope(select, exp.AggFunc) for select in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        join_type = "CROSS"
        on_clause = None
        if isinstance(predicate, exp.Exists):
            # If a subquery returns no rows, cross-joining against it incorrectly eliminates all rows
            # from the parent query. Therefore, we use a LEFT JOIN that always matches (ON TRUE), then
            # check for non-NULL column values to determine whether the subquery contained rows.
            column = column.is_(exp.null()).not_()
            join_type = "LEFT"
            on_clause = exp.true()

        # Swap the subquery node for the column reference, then attach the join.
        _replace(select.parent, column)
        parent_select.join(select, on=on_clause, join_type=join_type, join_alias=alias, copy=False)

        return

    # IN / ANY path: subqueries with LIMIT or OFFSET are left alone.
    if select.find(exp.Limit, exp.Offset):
        return

    # For ANY, the predicate to rewrite is the enclosing equality.
    if isinstance(predicate, exp.Any):
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    # Operand compared against the subquery, and the subquery's single projection.
    column = _other_operand(predicate)
    value = select.selects[0]

    join_key = exp.column(value.alias, alias)
    join_key_not_null = join_key.is_(exp.null()).not_()

    if isinstance(clause, exp.Join):
        # Inside a JOIN clause: neutralise the predicate there and move the
        # NOT NULL check into the parent's WHERE.
        _replace(predicate, exp.true())
        parent_select.where(join_key_not_null, copy=False)
    else:
        _replace(predicate, join_key_not_null)

    group = select.args.get("group")

    if group:
        # An existing GROUP BY on something other than exactly the value: wrap the
        # subquery in an outer select ("_q") that groups by the value.
        if {value.this} != set(group.expressions):
            select = (
                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
                .from_(select.subquery("_q", copy=False), copy=False)
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    elif not find_in_scope(value.this, exp.AggFunc):
        # No GROUP BY and the value is not an aggregate: group by the value.
        select = select.group_by(value.this, copy=False)

    parent_select.join(
        select,
        on=column.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a grouped subquery LEFT JOINed onto `parent_select`.

    The correlation predicates found in the subquery's WHERE are replaced with TRUE
    there; the equality predicates among them are reused as the join's ON condition.
    The predicate in the parent that consumed the subquery is rewritten to use the
    joined columns.

    The function returns before modifying the tree when the subquery has no WHERE,
    its WHERE contains OR, it has LIMIT/OFFSET, an external column is not inside a
    binary predicate of this WHERE, or none of the correlation predicates is an equality.

    Args:
        select: the correlated subquery expression.
        parent_select: the enclosing select that receives the join.
        external_columns: columns used in `select` that come from an outer query
            (`scope.external_columns` at the call site in `unnest_subqueries`).
        next_alias_name: callable returning a fresh alias string on each call.
    Returns:
        None.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    # (key, column, predicate) triples: `column` is the external column, `key` is the
    # operand on the other side of `predicate`.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        # The external column must be used in this subquery's own WHERE.
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # Pick the side of the predicate that does not contain the external column.
            key = (
                predicate.right
                if any(node is column for node in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # At least one equality is required.
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery is itself one of the parent's projections (ignoring alias).
    is_subquery_projection = any(
        node is select.parent
        for node in map(lambda s: s.unalias(), parent_select.selects)
        if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}  # key expression -> alias it is projected under
    group_by = []  # key expressions the subquery will be grouped by

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    # Predicate in the outer query that consumes the subquery (None if there is none).
    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        # append=False: the aggregated value replaces the existing projections.
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.set("expressions", [])

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped on are projected through the aggregate instead.
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # From here on `alias` is the column that refers to the subquery value via the join.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)
    op_type = type(parent_predicate.parent) if parent_predicate else None

    # Rewrite the consuming predicate. The string conditions are turned into
    # expressions by `exp.condition` inside `_replace`.
    if isinstance(parent_predicate, exp.Exists):
        # EXISTS: check that the first projected key is not NULL after the LEFT JOIN.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        assert issubclass(op_type, exp.Binary)
        predicate = op_type(this=other, expression=exp.column("_x"))
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
        )
    elif isinstance(parent_predicate, exp.Any):
        assert issubclass(op_type, exp.Binary)
        if value.this in group_by:
            # The value is a grouping key, so a direct comparison is used.
            predicate = op_type(this=other, expression=alias)
            parent_predicate = _replace(parent_predicate.parent, predicate)
        else:
            predicate = op_type(this=other, expression=exp.column("_x"))
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> {predicate})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # Any other consumer (or none): substitute the subquery node itself.
        if is_subquery_projection and select.parent.alias:
            # Keep the projection's original alias.
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                # COUNT -> 0, every other aggregate -> NULL, everything else unchanged.
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)])

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # Remove the correlation predicate from the subquery's WHERE.
        predicate.replace(exp.true())
        # Column exposing this key through the join alias.
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            # Non-equality predicates are not part of the ON list below, so they are
            # added to the parent's WHERE instead.
            if not isinstance(predicate, exp.EQ):
                parent_select.where(predicate, copy=False)
            continue

        # NOTE(review): the two branches below call `_replace(parent_predicate, ...)`;
        # presumably parent_predicate is never None when they are reached — confirm.
        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # Attach the grouped subquery; only the equality predicates form the ON condition.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _replace(expression, condition):
    """Replace `expression` in its tree with `condition` passed through `exp.condition`.

    Returns the result of `expression.replace(...)`.
    """
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """Return the operand that a subquery-like expression is being compared against.

    * IN: the left-hand side (`expression.this`).
    * ANY / ALL: the other operand of the parent expression.
    * Binary: the right side when the left is a Subquery/Any/Exists/All, else the left.
    * Anything else (including None): None.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
def
unnest_subqueries(expression):
def unnest_subqueries(expression):
    """
    Replace subquery predicates in a sqlglot AST with joins where possible.

    Scalar subqueries are turned into cross joins. Correlated or vectorized
    subqueries are turned into grouped subqueries that are left-joined, which keeps
    the join from becoming many to many.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: the same expression object, after rewriting
    """
    # One generator for the whole call so that every new join alias is distinct.
    alias_generator = name_sequence("_u_")

    for current_scope in traverse_scope(expression):
        subquery = current_scope.expression
        outer_select = subquery.parent_select

        if outer_select:
            if current_scope.external_columns:
                # Correlated: the subquery refers to columns of an outer query.
                decorrelate(
                    subquery,
                    outer_select,
                    current_scope.external_columns,
                    alias_generator,
                )
            elif current_scope.scope_type == ScopeType.SUBQUERY:
                unnest(subquery, outer_select, alias_generator)

    return expression
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Scalar subqueries are converted into cross joins. Correlated or vectorized subqueries are converted into a grouped subquery that is joined with a LEFT JOIN, so that the join cannot become many-to-many.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
- expression (sqlglot.Expression): expression to unnest
Returns:
sqlglot.Expression: unnested expression
def
unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Turn an uncorrelated, single-projection subquery into a join on `parent_select`.

    Two shapes are handled:
      * the enclosing condition is not IN/ANY: the subquery is treated as a scalar and
        joined with a CROSS JOIN (LEFT JOIN ... ON TRUE for EXISTS);
      * the enclosing condition is IN/ANY: the subquery is grouped by its value and
        LEFT JOINed, and the predicate becomes a NOT NULL check on the join key.

    Args:
        select: the subquery expression to unnest.
        parent_select: the select that encloses the subquery and receives the join.
        next_alias_name: callable returning a fresh alias string on each call.
    Returns:
        None. The function returns early when the subquery shape is not supported.
    """
    # Only subqueries with at most one projection are handled.
    if len(select.selects) > 1:
        return

    # Nearest enclosing condition node; it decides which rewrite applies.
    predicate = select.find_ancestor(exp.Condition)
    if (
        not predicate
        # Do not unnest subqueries inside table-valued functions such as
        # FROM GENERATE_SERIES(...), FROM UNNEST(...) etc in order to preserve join order
        or (
            isinstance(predicate, exp.Func)
            and isinstance(predicate.parent, (exp.Table, exp.From, exp.Join))
        )
        # The condition must live directly in the parent select's own scope ...
        or parent_select is not predicate.parent_select
        # ... and the parent needs a FROM clause to attach a join to.
        or not parent_select.args.get("from_")
    ):
        return

    # Wrap set operations in a plain SELECT over an aliased subquery so that the code
    # below can treat `select` uniformly.
    if isinstance(select, exp.SetOperation):
        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))

    alias = next_alias_name()
    # Clause (HAVING / WHERE / JOIN) the condition sits in, if any.
    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        # Reference to the subquery's single projection through the join alias.
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause_parent_select = clause.parent_select if clause else None

        # Wrap the reference in MAX when it is used in the parent's HAVING, or when it is
        # used outside the parent's own clauses while the parent groups or aggregates.
        # Note: `select` inside the generator below is the generator's own loop variable
        # (a projection of parent_select); it does not rebind the `select` argument.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(find_in_scope(select, exp.AggFunc) for select in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        join_type = "CROSS"
        on_clause = None
        if isinstance(predicate, exp.Exists):
            # If a subquery returns no rows, cross-joining against it incorrectly eliminates all rows
            # from the parent query. Therefore, we use a LEFT JOIN that always matches (ON TRUE), then
            # check for non-NULL column values to determine whether the subquery contained rows.
            column = column.is_(exp.null()).not_()
            join_type = "LEFT"
            on_clause = exp.true()

        # Swap the subquery node for the column reference, then attach the join.
        _replace(select.parent, column)
        parent_select.join(select, on=on_clause, join_type=join_type, join_alias=alias, copy=False)

        return

    # IN / ANY path: subqueries with LIMIT or OFFSET are left alone.
    if select.find(exp.Limit, exp.Offset):
        return

    # For ANY, the predicate to rewrite is the enclosing equality.
    if isinstance(predicate, exp.Any):
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    # Operand compared against the subquery, and the subquery's single projection.
    column = _other_operand(predicate)
    value = select.selects[0]

    join_key = exp.column(value.alias, alias)
    join_key_not_null = join_key.is_(exp.null()).not_()

    if isinstance(clause, exp.Join):
        # Inside a JOIN clause: neutralise the predicate there and move the
        # NOT NULL check into the parent's WHERE.
        _replace(predicate, exp.true())
        parent_select.where(join_key_not_null, copy=False)
    else:
        _replace(predicate, join_key_not_null)

    group = select.args.get("group")

    if group:
        # An existing GROUP BY on something other than exactly the value: wrap the
        # subquery in an outer select ("_q") that groups by the value.
        if {value.this} != set(group.expressions):
            select = (
                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
                .from_(select.subquery("_q", copy=False), copy=False)
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    elif not find_in_scope(value.this, exp.AggFunc):
        # No GROUP BY and the value is not an aggregate: group by the value.
        select = select.group_by(value.this, copy=False)

    parent_select.join(
        select,
        on=column.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def
decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a grouped subquery LEFT JOINed onto `parent_select`.

    The correlation predicates found in the subquery's WHERE are replaced with TRUE
    there; the equality predicates among them are reused as the join's ON condition.
    The predicate in the parent that consumed the subquery is rewritten to use the
    joined columns.

    The function returns before modifying the tree when the subquery has no WHERE,
    its WHERE contains OR, it has LIMIT/OFFSET, an external column is not inside a
    binary predicate of this WHERE, or none of the correlation predicates is an equality.

    Args:
        select: the correlated subquery expression.
        parent_select: the enclosing select that receives the join.
        external_columns: columns used in `select` that come from an outer query
            (`scope.external_columns` at the call site in `unnest_subqueries`).
        next_alias_name: callable returning a fresh alias string on each call.
    Returns:
        None.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    # (key, column, predicate) triples: `column` is the external column, `key` is the
    # operand on the other side of `predicate`.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        # The external column must be used in this subquery's own WHERE.
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # Pick the side of the predicate that does not contain the external column.
            key = (
                predicate.right
                if any(node is column for node in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # At least one equality is required.
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery is itself one of the parent's projections (ignoring alias).
    is_subquery_projection = any(
        node is select.parent
        for node in map(lambda s: s.unalias(), parent_select.selects)
        if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}  # key expression -> alias it is projected under
    group_by = []  # key expressions the subquery will be grouped by

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    # Predicate in the outer query that consumes the subquery (None if there is none).
    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        # append=False: the aggregated value replaces the existing projections.
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.set("expressions", [])

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped on are projected through the aggregate instead.
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # From here on `alias` is the column that refers to the subquery value via the join.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)
    op_type = type(parent_predicate.parent) if parent_predicate else None

    # Rewrite the consuming predicate. The string conditions are turned into
    # expressions by `exp.condition` inside `_replace`.
    if isinstance(parent_predicate, exp.Exists):
        # EXISTS: check that the first projected key is not NULL after the LEFT JOIN.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        assert issubclass(op_type, exp.Binary)
        predicate = op_type(this=other, expression=exp.column("_x"))
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
        )
    elif isinstance(parent_predicate, exp.Any):
        assert issubclass(op_type, exp.Binary)
        if value.this in group_by:
            # The value is a grouping key, so a direct comparison is used.
            predicate = op_type(this=other, expression=alias)
            parent_predicate = _replace(parent_predicate.parent, predicate)
        else:
            predicate = op_type(this=other, expression=exp.column("_x"))
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> {predicate})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # Any other consumer (or none): substitute the subquery node itself.
        if is_subquery_projection and select.parent.alias:
            # Keep the projection's original alias.
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                # COUNT -> 0, every other aggregate -> NULL, everything else unchanged.
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)])

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # Remove the correlation predicate from the subquery's WHERE.
        predicate.replace(exp.true())
        # Column exposing this key through the join alias.
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            # Non-equality predicates are not part of the ON list below, so they are
            # added to the parent's WHERE instead.
            if not isinstance(predicate, exp.EQ):
                parent_select.where(predicate, copy=False)
            continue

        # NOTE(review): the two branches below call `_replace(parent_predicate, ...)`;
        # presumably parent_predicate is never None when they are reached — confirm.
        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # Attach the grouped subquery; only the equality predicates form the ON condition.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )
value.this in group_by: 244 parent_predicate = _replace(parent_predicate, f"{other} = {alias}") 245 else: 246 parent_predicate = _replace( 247 parent_predicate, 248 f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})", 249 ) 250 else: 251 if is_subquery_projection and select.parent.alias: 252 alias = exp.alias_(alias, select.parent.alias) 253 254 # COUNT always returns 0 on empty datasets, so we need take that into consideration here 255 # by transforming all counts into 0 and using that as the coalesced value 256 if value.find(exp.Count): 257 258 def remove_aggs(node): 259 if isinstance(node, exp.Count): 260 return exp.Literal.number(0) 261 elif isinstance(node, exp.AggFunc): 262 return exp.null() 263 return node 264 265 alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)]) 266 267 select.parent.replace(alias) 268 269 for key, column, predicate in keys: 270 predicate.replace(exp.true()) 271 nested = exp.column(key_aliases[key], table_alias) 272 273 if is_subquery_projection: 274 key.replace(nested) 275 if not isinstance(predicate, exp.EQ): 276 parent_select.where(predicate, copy=False) 277 continue 278 279 if key in group_by: 280 key.replace(nested) 281 elif isinstance(predicate, exp.EQ): 282 parent_predicate = _replace( 283 parent_predicate, 284 f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))", 285 ) 286 else: 287 key.replace(exp.to_identifier("_x")) 288 parent_predicate = _replace( 289 parent_predicate, 290 f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))", 291 ) 292 293 parent_select.join( 294 select.group_by(*group_by, copy=False), 295 on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)], 296 join_type="LEFT", 297 join_alias=table_alias, 298 copy=False, 299 )