sqlglot.optimizer.unnest_subqueries
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, find_in_scope, traverse_scope


def unnest_subqueries(expression: exp.Expr) -> exp.Expr:
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert scalar subqueries into cross joins.
    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.

    The expression is rewritten in place and the same object is returned.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expr): expression to unnest
    Returns:
        sqlglot.Expr: unnested expression
    """
    # One shared alias generator so every join added during this call draws from the
    # same "_u_" sequence.
    next_alias_name = name_sequence("_u_")

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        # Scopes that are not nested inside another select have nothing to be joined into.
        if not parent:
            continue
        if scope.external_columns:
            # The scope references columns from outside itself, i.e. it is correlated.
            decorrelate(select, parent, scope.external_columns, next_alias_name)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(select, parent, next_alias_name)

    return expression


def unnest(select, parent_select, next_alias_name):
    """
    Rewrite an uncorrelated subquery that is used inside a condition as a join, in place.

    Nothing is changed when the subquery has more than one projection, has no
    `exp.Condition` ancestor, sits inside a function directly under a table/FROM/JOIN,
    when that condition belongs to a different select than `parent_select`, or when
    `parent_select` has no FROM clause.

    For conditions other than IN / ANY the subquery is replaced by a column of the joined
    alias and added as a CROSS JOIN (or a LEFT JOIN ... ON TRUE for EXISTS). For IN / ANY
    the predicate is replaced by a NOT NULL check on the join key and the subquery is
    LEFT JOINed on `<other operand> = <join key>`.

    Args:
        select: the subquery expression to unnest.
        parent_select: the select that encloses the subquery and receives the join.
        next_alias_name: zero-argument callable returning the next alias name to use.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    if (
        not predicate
        # Do not unnest subqueries inside table-valued functions such as
        # FROM GENERATE_SERIES(...), FROM UNNEST(...) etc in order to preserve join order
        or (
            isinstance(predicate, exp.Func)
            and isinstance(predicate.parent, (exp.Table, exp.From, exp.Join))
        )
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from_")
    ):
        return

    # A set operation is wrapped in a plain SELECT over an aliased subquery so the rest of
    # the function can treat it like any other select. This consumes one alias name.
    if isinstance(select, exp.SetOperation):
        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))

    alias = next_alias_name()
    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause_parent_select = clause.parent_select if clause else None

        # Wrap the joined column in MAX when the reference lives in the parent's HAVING, or
        # when it lives outside the parent's own HAVING/WHERE/JOIN and the parent groups or
        # aggregates. NOTE: the generator variable below is also called `select`; it is local
        # to the generator expression and does not rebind the `select` parameter.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(find_in_scope(select, exp.AggFunc) for select in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        join_type = "CROSS"
        on_clause = None
        if isinstance(predicate, exp.Exists):
            # If a subquery returns no rows, cross-joining against it incorrectly eliminates all rows
            # from the parent query. Therefore, we use a LEFT JOIN that always matches (ON TRUE), then
            # check for non-NULL column values to determine whether the subquery contained rows.
            column = column.is_(exp.null()).not_()
            join_type = "LEFT"
            on_clause = exp.true()

        # Swap the subquery wrapper for the column (or NOT NULL check), then attach the join.
        _replace(select.parent, column)
        parent_select.join(select, on=on_clause, join_type=join_type, join_alias=alias, copy=False)

        return

    # IN / ANY subqueries containing LIMIT or OFFSET are left untouched.
    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # For ANY, only an enclosing equality that belongs to the parent select is handled.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    # The operand compared against the subquery; fetched before the predicate is replaced.
    column = _other_operand(predicate)
    value = select.selects[0]

    # NOTE(review): uses `value.alias` rather than `alias_or_name`, so this presumably
    # expects the projection to carry an explicit alias - confirm against callers.
    join_key = exp.column(value.alias, alias)
    join_key_not_null = join_key.is_(exp.null()).not_()

    if isinstance(clause, exp.Join):
        # Inside a JOIN the predicate becomes TRUE and the NOT NULL check is added to the
        # parent's WHERE instead.
        _replace(predicate, exp.true())
        parent_select.where(join_key_not_null, copy=False)
    else:
        _replace(predicate, join_key_not_null)

    group = select.args.get("group")

    if group:
        # An existing GROUP BY that is not exactly the projected value gets wrapped in an
        # outer select "_q" that groups by the value.
        if {value.this} != set(group.expressions):
            select = (
                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
                .from_(select.subquery("_q", copy=False), copy=False)
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    elif not find_in_scope(value.this, exp.AggFunc):
        # No GROUP BY and no aggregate in the value: group by the value itself.
        select = select.group_by(value.this, copy=False)

    parent_select.join(
        select,
        on=column.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a grouped LEFT JOIN on its parent select, in place.

    Nothing is changed when the subquery has no WHERE, its WHERE contains an OR, it contains
    LIMIT or OFFSET, an external column is not inside a binary predicate of that WHERE, or
    none of those predicates is an equality. It also returns early when the subquery is
    neither under a predicate nor a direct projection of the parent.
    NOTE(review): that last early return happens after an alias name has been consumed.

    Args:
        select: the correlated subquery expression.
        parent_select: the select that encloses the subquery and receives the join.
        external_columns: columns inside `select` that refer to outer scopes.
        next_alias_name: zero-argument callable returning the next alias name to use.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    # Collected as (key, column, predicate): the external column, the predicate it sits in,
    # and the opposite side of that predicate.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # The key is whichever side of the predicate does not contain the external column.
            key = (
                predicate.right
                if any(node is column for node in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # At least one equality is required.
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery wrapper is itself one of the parent's (unaliased) projections.
    is_subquery_projection = any(
        node is select.parent
        for node in map(lambda s: s.unalias(), parent_select.selects)
        if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # When the subquery is embedded inside a function (e.g. COALESCE, TRIM) in the SELECT list,
    # the ancestor chain contains no Predicate node AND the subquery is not a direct projection.
    if parent_predicate is None and not is_subquery_projection:
        return

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        # NOTE(review): append=False presumably replaces the existing projections - confirm.
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.set("expressions", [])

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped on are projected through the aggregate instead.
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # Column of the joined alias that stands in for the subquery's value.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)
    op_type = type(parent_predicate.parent) if parent_predicate else None

    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a NOT NULL check on the first key alias of the joined subquery.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        assert issubclass(op_type, exp.Binary)
        # The comparison that wrapped ALL is rebuilt against a lambda variable `_x`.
        predicate = op_type(this=other, expression=exp.column("_x"))
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
        )
    elif isinstance(parent_predicate, exp.Any):
        assert issubclass(op_type, exp.Binary)
        if value.this in group_by:
            # The value is a grouping key, so compare against the joined column directly.
            predicate = op_type(this=other, expression=alias)
            parent_predicate = _replace(parent_predicate.parent, predicate)
        else:
            predicate = op_type(this=other, expression=exp.column("_x"))
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> {predicate})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # Keep the projection's original alias when the subquery was a projection.
        if is_subquery_projection and select.parent.alias:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)])

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # Each correlated predicate is taken out of the subquery's WHERE (replaced by TRUE);
        # the detached predicate object is reused below.
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            # Non-equality predicates move to the parent's WHERE; equalities are used in the
            # join's ON clause below.
            if not isinstance(predicate, exp.EQ):
                parent_select.where(predicate, copy=False)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # Join the grouped subquery on the equality predicates collected above.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _replace(expression, condition):
    """
    Replace `expression` in its tree with `condition`.

    `condition` is passed through `exp.condition` first, so callers hand in either an
    expression or a SQL string. The value returned by `expression.replace` is passed back;
    callers in this module treat it as the node now in the tree.
    """
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """
    Return the operand of `expression` that is compared against a subquery.

    For IN this is `expression.this`; ANY / ALL defer to their parent expression; for a
    binary expression it is the right side when the left side is a subquery, ANY, EXISTS or
    ALL node, and the left side otherwise. Any other input (including None) yields None.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
def unnest_subqueries(expression: exp.Expr) -> exp.Expr:
    """
    Turn subquery predicates in a sqlglot AST into joins where possible.

    Scalar subqueries become cross joins. Correlated or vectorized subqueries become
    grouped subqueries that are LEFT JOINed, so the join is not many-to-many.
    The tree is modified in place and the same object is returned.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expr): expression to unnest
    Returns:
        sqlglot.Expr: unnested expression
    """
    # A single generator is shared so that all joins added here use one "_u_" sequence.
    alias_names = name_sequence("_u_")

    for scope in traverse_scope(expression):
        subquery = scope.expression
        outer_select = subquery.parent_select

        # Only scopes nested inside another select can be turned into a join on it.
        if outer_select:
            outer_columns = scope.external_columns
            if outer_columns:
                # References to outer columns mean the subquery is correlated.
                decorrelate(subquery, outer_select, outer_columns, alias_names)
            elif scope.scope_type == ScopeType.SUBQUERY:
                unnest(subquery, outer_select, alias_names)

    return expression
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into grouped (GROUP BY) subqueries so that the resulting LEFT JOIN is not many-to-many.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> unnest_subqueries(expression).sql() 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
- expression (sqlglot.Expr): expression to unnest
Returns:
sqlglot.Expr: unnested expression
def
unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Turn an uncorrelated subquery used inside a condition into a join on its parent, in place.

    The subquery is left untouched when it projects more than one expression, when it is
    not inside a condition owned by `parent_select`, when that condition is a function
    sitting directly under a table/FROM/JOIN, or when `parent_select` has no FROM clause.

    Conditions other than IN / ANY are handled as scalar lookups: the subquery is swapped
    for a column of the joined alias and attached with a CROSS JOIN (LEFT JOIN ... ON TRUE
    for EXISTS). IN / ANY conditions are swapped for a NOT NULL check on the join key and
    the subquery is LEFT JOINed on `<other operand> = <join key>`.

    Args:
        select: the subquery expression to unnest.
        parent_select: the select that encloses the subquery and receives the join.
        next_alias_name: zero-argument callable returning the next alias name to use.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    if not predicate:
        return

    # Subqueries inside table-valued functions such as FROM GENERATE_SERIES(...) or
    # FROM UNNEST(...) are left alone in order to preserve join order.
    if isinstance(predicate, exp.Func) and isinstance(
        predicate.parent, (exp.Table, exp.From, exp.Join)
    ):
        return

    if predicate.parent_select is not parent_select:
        return

    if not parent_select.args.get("from_"):
        return

    # Set operations are wrapped in a plain SELECT over an aliased subquery first; this
    # takes one alias name before the join alias below is drawn.
    if isinstance(select, exp.SetOperation):
        projections = select.selects
        select = exp.select(*projections).from_(select.subquery(next_alias_name()))

    join_alias = next_alias_name()
    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # Anything other than IN / ANY is treated as a scalar and becomes a cross join.
    if not isinstance(predicate, (exp.In, exp.Any)):
        scalar = exp.column(select.selects[0].alias_or_name, join_alias)
        clause_owner = clause.parent_select if clause else None

        # MAX is needed when the reference sits in the parent's HAVING, or sits outside the
        # parent's own HAVING/WHERE/JOIN while the parent groups or aggregates.
        if isinstance(clause, exp.Having) and clause_owner is parent_select:
            needs_max = True
        elif not clause or clause_owner is not parent_select:
            needs_max = bool(
                parent_select.args.get("group")
                or any(
                    find_in_scope(projection, exp.AggFunc)
                    for projection in parent_select.selects
                )
            )
        else:
            needs_max = False

        if needs_max:
            scalar = exp.Max(this=scalar)
        elif not isinstance(select.parent, exp.Subquery):
            return

        if isinstance(predicate, exp.Exists):
            # A cross join against an empty subquery would drop every parent row, so EXISTS
            # uses a LEFT JOIN that always matches (ON TRUE) and tests the joined column for
            # non-NULL to learn whether the subquery produced rows.
            scalar = scalar.is_(exp.null()).not_()
            join_type, on_clause = "LEFT", exp.true()
        else:
            join_type, on_clause = "CROSS", None

        _replace(select.parent, scalar)
        parent_select.join(
            select,
            on=on_clause,
            join_type=join_type,
            join_alias=join_alias,
            copy=False,
        )
        return

    # IN / ANY subqueries with LIMIT or OFFSET are not rewritten.
    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # Only an enclosing equality owned by the parent select is supported for ANY.
        predicate = predicate.find_ancestor(exp.EQ)
        if not predicate or parent_select is not predicate.parent_select:
            return

    # Grab the compared operand before the predicate gets replaced.
    outer_operand = _other_operand(predicate)
    value = select.selects[0]
    value_alias = value.alias

    join_key = exp.column(value_alias, join_alias)
    join_key_present = join_key.is_(exp.null()).not_()

    if isinstance(clause, exp.Join):
        # Within a JOIN the predicate collapses to TRUE and the check moves to WHERE.
        _replace(predicate, exp.true())
        parent_select.where(join_key_present, copy=False)
    else:
        _replace(predicate, join_key_present)

    group = select.args.get("group")

    if not group:
        # Without a GROUP BY, group by the value unless it already aggregates.
        if not find_in_scope(value.this, exp.AggFunc):
            select = select.group_by(value.this, copy=False)
    elif {value.this} != set(group.expressions):
        # A GROUP BY on something other than the value is wrapped in an outer select "_q"
        # that groups by the value.
        outer = exp.select(exp.alias_(exp.column(value_alias, "_q"), value_alias))
        outer = outer.from_(select.subquery("_q", copy=False), copy=False)
        select = outer.group_by(exp.column(value_alias, "_q"), copy=False)

    parent_select.join(
        select,
        on=outer_operand.eq(join_key),
        join_type="LEFT",
        join_alias=join_alias,
        copy=False,
    )
def
decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a grouped LEFT JOIN on its parent select, in place.

    Nothing is changed when the subquery has no WHERE, its WHERE contains an OR, it contains
    LIMIT or OFFSET, an external column is not inside a binary predicate of that WHERE, or
    none of those predicates is an equality. It also returns early when the subquery is
    neither under a predicate nor a direct projection of the parent.
    NOTE(review): that last early return happens after an alias name has been consumed.

    Args:
        select: the correlated subquery expression.
        parent_select: the select that encloses the subquery and receives the join.
        external_columns: columns inside `select` that refer to outer scopes.
        next_alias_name: zero-argument callable returning the next alias name to use.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    # Collected as (key, column, predicate): the external column, the predicate it sits in,
    # and the opposite side of that predicate.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # The key is whichever side of the predicate does not contain the external column.
            key = (
                predicate.right
                if any(node is column for node in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # At least one equality is required.
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery wrapper is itself one of the parent's (unaliased) projections.
    is_subquery_projection = any(
        node is select.parent
        for node in map(lambda s: s.unalias(), parent_select.selects)
        if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # When the subquery is embedded inside a function (e.g. COALESCE, TRIM) in the SELECT list,
    # the ancestor chain contains no Predicate node AND the subquery is not a direct projection.
    if parent_predicate is None and not is_subquery_projection:
        return

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        # NOTE(review): append=False presumably replaces the existing projections - confirm.
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.set("expressions", [])

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped on are projected through the aggregate instead.
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # Column of the joined alias that stands in for the subquery's value.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)
    op_type = type(parent_predicate.parent) if parent_predicate else None

    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a NOT NULL check on the first key alias of the joined subquery.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        assert issubclass(op_type, exp.Binary)
        # The comparison that wrapped ALL is rebuilt against a lambda variable `_x`.
        predicate = op_type(this=other, expression=exp.column("_x"))
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
        )
    elif isinstance(parent_predicate, exp.Any):
        assert issubclass(op_type, exp.Binary)
        if value.this in group_by:
            # The value is a grouping key, so compare against the joined column directly.
            predicate = op_type(this=other, expression=alias)
            parent_predicate = _replace(parent_predicate.parent, predicate)
        else:
            predicate = op_type(this=other, expression=exp.column("_x"))
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> {predicate})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # Keep the projection's original alias when the subquery was a projection.
        if is_subquery_projection and select.parent.alias:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)])

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # Each correlated predicate is taken out of the subquery's WHERE (replaced by TRUE);
        # the detached predicate object is reused below.
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            # Non-equality predicates move to the parent's WHERE; equalities are used in the
            # join's ON clause below.
            if not isinstance(predicate, exp.EQ):
                parent_select.where(predicate, copy=False)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # Join the grouped subquery on the equality predicates collected above.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )