Edit on GitHub

sqlglot.optimizer.unnest_subqueries

  1from sqlglot import exp
  2from sqlglot.helper import name_sequence
  3from sqlglot.optimizer.scope import ScopeType, find_in_scope, traverse_scope
  4
  5
  6def unnest_subqueries(expression):
  7    """
  8    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
  9
 10    Convert scalar subqueries into cross joins.
 11    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
 12
 13    Example:
 14        >>> import sqlglot
 15        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
 16        >>> unnest_subqueries(expression).sql()
 17        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
 18
 19    Args:
 20        expression (sqlglot.Expression): expression to unnest
 21    Returns:
 22        sqlglot.Expression: unnested expression
 23    """
 24    next_alias_name = name_sequence("_u_")
 25
 26    for scope in traverse_scope(expression):
 27        select = scope.expression
 28        parent = select.parent_select
 29        if not parent:
 30            continue
 31        if scope.external_columns:
 32            decorrelate(select, parent, scope.external_columns, next_alias_name)
 33        elif scope.scope_type == ScopeType.SUBQUERY:
 34            unnest(select, parent, next_alias_name)
 35
 36    return expression
 37
 38
 39def unnest(select, parent_select, next_alias_name):
 40    if len(select.selects) > 1:
 41        return
 42
 43    predicate = select.find_ancestor(exp.Condition)
 44    if (
 45        not predicate
 46        # Do not unnest subqueries inside table-valued functions such as
 47        # FROM GENERATE_SERIES(...), FROM UNNEST(...) etc in order to preserve join order
 48        or (
 49            isinstance(predicate, exp.Func)
 50            and isinstance(predicate.parent, (exp.Table, exp.From, exp.Join))
 51        )
 52        or parent_select is not predicate.parent_select
 53        or not parent_select.args.get("from_")
 54    ):
 55        return
 56
 57    if isinstance(select, exp.SetOperation):
 58        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))
 59
 60    alias = next_alias_name()
 61    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
 62
 63    # This subquery returns a scalar and can just be converted to a cross join
 64    if not isinstance(predicate, (exp.In, exp.Any)):
 65        column = exp.column(select.selects[0].alias_or_name, alias)
 66
 67        clause_parent_select = clause.parent_select if clause else None
 68
 69        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
 70            (not clause or clause_parent_select is not parent_select)
 71            and (
 72                parent_select.args.get("group")
 73                or any(find_in_scope(select, exp.AggFunc) for select in parent_select.selects)
 74            )
 75        ):
 76            column = exp.Max(this=column)
 77        elif not isinstance(select.parent, exp.Subquery):
 78            return
 79
 80        join_type = "CROSS"
 81        on_clause = None
 82        if isinstance(predicate, exp.Exists):
 83            # If a subquery returns no rows, cross-joining against it incorrectly eliminates all rows
 84            # from the parent query. Therefore, we use a LEFT JOIN that always matches (ON TRUE), then
 85            # check for non-NULL column values to determine whether the subquery contained rows.
 86            column = column.is_(exp.null()).not_()
 87            join_type = "LEFT"
 88            on_clause = exp.true()
 89
 90        _replace(select.parent, column)
 91        parent_select.join(select, on=on_clause, join_type=join_type, join_alias=alias, copy=False)
 92
 93        return
 94
 95    if select.find(exp.Limit, exp.Offset):
 96        return
 97
 98    if isinstance(predicate, exp.Any):
 99        predicate = predicate.find_ancestor(exp.EQ)
100
101        if not predicate or parent_select is not predicate.parent_select:
102            return
103
104    column = _other_operand(predicate)
105    value = select.selects[0]
106
107    join_key = exp.column(value.alias, alias)
108    join_key_not_null = join_key.is_(exp.null()).not_()
109
110    if isinstance(clause, exp.Join):
111        _replace(predicate, exp.true())
112        parent_select.where(join_key_not_null, copy=False)
113    else:
114        _replace(predicate, join_key_not_null)
115
116    group = select.args.get("group")
117
118    if group:
119        if {value.this} != set(group.expressions):
120            select = (
121                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
122                .from_(select.subquery("_q", copy=False), copy=False)
123                .group_by(exp.column(value.alias, "_q"), copy=False)
124            )
125    elif not find_in_scope(value.this, exp.AggFunc):
126        select = select.group_by(value.this, copy=False)
127
128    parent_select.join(
129        select,
130        on=column.eq(join_key),
131        join_type="LEFT",
132        join_alias=alias,
133        copy=False,
134    )
135
136
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery of `parent_select` into a LEFT JOIN.

    Every predicate in the subquery's WHERE that references an external column
    is replaced by TRUE inside the subquery. The predicate's other operand
    (the "key") is projected from the subquery under an alias so the parent
    can join or filter on it. Keys of equality predicates are added to the
    subquery's GROUP BY; other keys are aggregated.

    The function returns early, before modifying the tree, when:
      * the subquery has no WHERE, its WHERE contains an OR, or the subquery
        contains LIMIT / OFFSET,
      * an external column is not inside that WHERE, or is not inside a binary
        predicate within it,
      * none of the correlating predicates is an equality.

    Args:
        select: the correlated subquery expression.
        parent_select: the select containing the subquery; it receives the join.
        external_columns: the columns the subquery's scope reports as external.
        next_alias_name: zero-argument callable returning a new alias per call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # The key is the operand that does NOT contain the external column.
            key = (
                predicate.right
                if any(node is column for node in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery node is itself one of the parent's projections
    # (ignoring an alias around it).
    is_subquery_projection = any(
        node is select.parent
        for node in map(lambda s: s.unalias(), parent_select.selects)
        if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.set("expressions", [])

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # Reference to the subquery's value through the join alias; `alias` is
    # reused here after serving as the loop variable above.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)
    op_type = type(parent_predicate.parent) if parent_predicate else None

    # Rewrite the predicate in the parent that consumed the subquery.
    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a NOT NULL check on the first projected key.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        assert issubclass(op_type, exp.Binary)
        predicate = op_type(this=other, expression=exp.column("_x"))
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
        )
    elif isinstance(parent_predicate, exp.Any):
        assert issubclass(op_type, exp.Binary)
        if value.this in group_by:
            predicate = op_type(this=other, expression=alias)
            parent_predicate = _replace(parent_predicate.parent, predicate)
        else:
            predicate = op_type(this=other, expression=exp.column("_x"))
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> {predicate})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # Any other usage: the subquery node is replaced by the joined column.
        if is_subquery_projection and select.parent.alias:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)])

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # Remove the correlating predicate from the subquery's WHERE; the
        # detached predicate object is reused below and in the join condition.
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            if not isinstance(predicate, exp.EQ):
                # Non-equality correlations become filters on the parent.
                parent_select.where(predicate, copy=False)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            # NOTE(review): parent_predicate may be None at this point (the
            # op_type assignment above guards against that); `_replace` would
            # then raise. Confirm whether that combination can occur.
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # Join on the equality predicates, whose key side now points at the alias.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )
299
300
def _replace(expression, condition):
    """
    Replace `expression` in its tree with `condition`.

    `condition` is passed through `exp.condition` first; callers in this module
    pass either a SQL string or an expression object.

    Returns:
        Whatever `expression.replace` returns.
    """
    return expression.replace(exp.condition(condition))
303
304
def _other_operand(expression):
    """
    Return the operand that a subquery predicate is compared against.

    * IN: the left-hand side (`expression.this`).
    * ANY / ALL: resolved from the parent expression.
    * Binary: the right operand if the left one is a Subquery / Any / Exists /
      All, otherwise the left operand.
    * Anything else (including None): None.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        # The comparison operator wraps the ANY / ALL node.
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
def unnest_subqueries(expression):
 7def unnest_subqueries(expression):
 8    """
 9    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
10
11    Convert scalar subqueries into cross joins.
12    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
13
14    Example:
15        >>> import sqlglot
16        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
17        >>> unnest_subqueries(expression).sql()
18        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
19
20    Args:
21        expression (sqlglot.Expression): expression to unnest
22    Returns:
23        sqlglot.Expression: unnested expression
24    """
25    next_alias_name = name_sequence("_u_")
26
27    for scope in traverse_scope(expression):
28        select = scope.expression
29        parent = select.parent_select
30        if not parent:
31            continue
32        if scope.external_columns:
33            decorrelate(select, parent, scope.external_columns, next_alias_name)
34        elif scope.scope_type == ScopeType.SUBQUERY:
35            unnest(select, parent, next_alias_name)
36
37    return expression

Rewrite sqlglot AST to convert some predicates with subqueries into joins.

Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into left joins against a grouped subquery, so that the join is not many-to-many.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
  • expression (sqlglot.Expression): expression to unnest
Returns:

sqlglot.Expression: unnested expression

def unnest(select, parent_select, next_alias_name):
 40def unnest(select, parent_select, next_alias_name):
 41    if len(select.selects) > 1:
 42        return
 43
 44    predicate = select.find_ancestor(exp.Condition)
 45    if (
 46        not predicate
 47        # Do not unnest subqueries inside table-valued functions such as
 48        # FROM GENERATE_SERIES(...), FROM UNNEST(...) etc in order to preserve join order
 49        or (
 50            isinstance(predicate, exp.Func)
 51            and isinstance(predicate.parent, (exp.Table, exp.From, exp.Join))
 52        )
 53        or parent_select is not predicate.parent_select
 54        or not parent_select.args.get("from_")
 55    ):
 56        return
 57
 58    if isinstance(select, exp.SetOperation):
 59        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))
 60
 61    alias = next_alias_name()
 62    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
 63
 64    # This subquery returns a scalar and can just be converted to a cross join
 65    if not isinstance(predicate, (exp.In, exp.Any)):
 66        column = exp.column(select.selects[0].alias_or_name, alias)
 67
 68        clause_parent_select = clause.parent_select if clause else None
 69
 70        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
 71            (not clause or clause_parent_select is not parent_select)
 72            and (
 73                parent_select.args.get("group")
 74                or any(find_in_scope(select, exp.AggFunc) for select in parent_select.selects)
 75            )
 76        ):
 77            column = exp.Max(this=column)
 78        elif not isinstance(select.parent, exp.Subquery):
 79            return
 80
 81        join_type = "CROSS"
 82        on_clause = None
 83        if isinstance(predicate, exp.Exists):
 84            # If a subquery returns no rows, cross-joining against it incorrectly eliminates all rows
 85            # from the parent query. Therefore, we use a LEFT JOIN that always matches (ON TRUE), then
 86            # check for non-NULL column values to determine whether the subquery contained rows.
 87            column = column.is_(exp.null()).not_()
 88            join_type = "LEFT"
 89            on_clause = exp.true()
 90
 91        _replace(select.parent, column)
 92        parent_select.join(select, on=on_clause, join_type=join_type, join_alias=alias, copy=False)
 93
 94        return
 95
 96    if select.find(exp.Limit, exp.Offset):
 97        return
 98
 99    if isinstance(predicate, exp.Any):
100        predicate = predicate.find_ancestor(exp.EQ)
101
102        if not predicate or parent_select is not predicate.parent_select:
103            return
104
105    column = _other_operand(predicate)
106    value = select.selects[0]
107
108    join_key = exp.column(value.alias, alias)
109    join_key_not_null = join_key.is_(exp.null()).not_()
110
111    if isinstance(clause, exp.Join):
112        _replace(predicate, exp.true())
113        parent_select.where(join_key_not_null, copy=False)
114    else:
115        _replace(predicate, join_key_not_null)
116
117    group = select.args.get("group")
118
119    if group:
120        if {value.this} != set(group.expressions):
121            select = (
122                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
123                .from_(select.subquery("_q", copy=False), copy=False)
124                .group_by(exp.column(value.alias, "_q"), copy=False)
125            )
126    elif not find_in_scope(value.this, exp.AggFunc):
127        select = select.group_by(value.this, copy=False)
128
129    parent_select.join(
130        select,
131        on=column.eq(join_key),
132        join_type="LEFT",
133        join_alias=alias,
134        copy=False,
135    )
def decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery of `parent_select` into a LEFT JOIN.

    Every predicate in the subquery's WHERE that references an external column
    is replaced by TRUE inside the subquery. The predicate's other operand
    (the "key") is projected from the subquery under an alias so the parent
    can join or filter on it. Keys of equality predicates are added to the
    subquery's GROUP BY; other keys are aggregated.

    The function returns early, before modifying the tree, when:
      * the subquery has no WHERE, its WHERE contains an OR, or the subquery
        contains LIMIT / OFFSET,
      * an external column is not inside that WHERE, or is not inside a binary
        predicate within it,
      * none of the correlating predicates is an equality.

    Args:
        select: the correlated subquery expression.
        parent_select: the select containing the subquery; it receives the join.
        external_columns: the columns the subquery's scope reports as external.
        next_alias_name: zero-argument callable returning a new alias per call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # The key is the operand that does NOT contain the external column.
            key = (
                predicate.right
                if any(node is column for node in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery node is itself one of the parent's projections
    # (ignoring an alias around it).
    is_subquery_projection = any(
        node is select.parent
        for node in map(lambda s: s.unalias(), parent_select.selects)
        if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.set("expressions", [])

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # Reference to the subquery's value through the join alias; `alias` is
    # reused here after serving as the loop variable above.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)
    op_type = type(parent_predicate.parent) if parent_predicate else None

    # Rewrite the predicate in the parent that consumed the subquery.
    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a NOT NULL check on the first projected key.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        assert issubclass(op_type, exp.Binary)
        predicate = op_type(this=other, expression=exp.column("_x"))
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
        )
    elif isinstance(parent_predicate, exp.Any):
        assert issubclass(op_type, exp.Binary)
        if value.this in group_by:
            predicate = op_type(this=other, expression=alias)
            parent_predicate = _replace(parent_predicate.parent, predicate)
        else:
            predicate = op_type(this=other, expression=exp.column("_x"))
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> {predicate})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # Any other usage: the subquery node is replaced by the joined column.
        if is_subquery_projection and select.parent.alias:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)])

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # Remove the correlating predicate from the subquery's WHERE; the
        # detached predicate object is reused below and in the join condition.
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            if not isinstance(predicate, exp.EQ):
                # Non-equality correlations become filters on the parent.
                parent_select.where(predicate, copy=False)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            # NOTE(review): parent_predicate may be None at this point (the
            # op_type assignment above guards against that); `_replace` would
            # then raise. Confirm whether that combination can occur.
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # Join on the equality predicates, whose key side now points at the alias.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )