Edit on GitHub

sqlglot.optimizer.unnest_subqueries

  1from sqlglot import exp
  2from sqlglot.helper import name_sequence
  3from sqlglot.optimizer.scope import ScopeType, find_in_scope, traverse_scope
  4
  5
  6def unnest_subqueries(expression):
  7    """
  8    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
  9
 10    Convert scalar subqueries into cross joins.
 11    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
 12
 13    Example:
 14        >>> import sqlglot
 15        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
 16        >>> unnest_subqueries(expression).sql()
 17        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
 18
 19    Args:
 20        expression (sqlglot.Expression): expression to unnest
 21    Returns:
 22        sqlglot.Expression: unnested expression
 23    """
 24    next_alias_name = name_sequence("_u_")
 25
 26    for scope in traverse_scope(expression):
 27        select = scope.expression
 28        parent = select.parent_select
 29        if not parent:
 30            continue
 31        if scope.external_columns:
 32            decorrelate(select, parent, scope.external_columns, next_alias_name)
 33        elif scope.scope_type == ScopeType.SUBQUERY:
 34            unnest(select, parent, next_alias_name)
 35
 36    return expression
 37
 38
 39def unnest(select, parent_select, next_alias_name):
 40    if len(select.selects) > 1:
 41        return
 42
 43    predicate = select.find_ancestor(exp.Condition)
 44    if (
 45        not predicate
 46        or parent_select is not predicate.parent_select
 47        or not parent_select.args.get("from")
 48    ):
 49        return
 50
 51    if isinstance(select, exp.SetOperation):
 52        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))
 53
 54    alias = next_alias_name()
 55    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
 56
 57    # This subquery returns a scalar and can just be converted to a cross join
 58    if not isinstance(predicate, (exp.In, exp.Any)):
 59        column = exp.column(select.selects[0].alias_or_name, alias)
 60
 61        clause_parent_select = clause.parent_select if clause else None
 62
 63        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
 64            (not clause or clause_parent_select is not parent_select)
 65            and (
 66                parent_select.args.get("group")
 67                or any(find_in_scope(select, exp.AggFunc) for select in parent_select.selects)
 68            )
 69        ):
 70            column = exp.Max(this=column)
 71        elif not isinstance(select.parent, exp.Subquery):
 72            return
 73
 74        _replace(select.parent, column)
 75        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
 76        return
 77
 78    if select.find(exp.Limit, exp.Offset):
 79        return
 80
 81    if isinstance(predicate, exp.Any):
 82        predicate = predicate.find_ancestor(exp.EQ)
 83
 84        if not predicate or parent_select is not predicate.parent_select:
 85            return
 86
 87    column = _other_operand(predicate)
 88    value = select.selects[0]
 89
 90    join_key = exp.column(value.alias, alias)
 91    join_key_not_null = join_key.is_(exp.null()).not_()
 92
 93    if isinstance(clause, exp.Join):
 94        _replace(predicate, exp.true())
 95        parent_select.where(join_key_not_null, copy=False)
 96    else:
 97        _replace(predicate, join_key_not_null)
 98
 99    group = select.args.get("group")
100
101    if group:
102        if {value.this} != set(group.expressions):
103            select = (
104                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
105                .from_(select.subquery("_q", copy=False), copy=False)
106                .group_by(exp.column(value.alias, "_q"), copy=False)
107            )
108    else:
109        select = select.group_by(value.this, copy=False)
110
111    parent_select.join(
112        select,
113        on=column.eq(join_key),
114        join_type="LEFT",
115        join_alias=alias,
116        copy=False,
117    )
118
119
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Turn a correlated subquery into a LEFT JOIN on a grouped derived table, in place.

    Requirements (otherwise the tree is left untouched):
    - the subquery has a WHERE clause that contains no OR;
    - the subquery has no LIMIT / OFFSET;
    - every external column sits inside a binary predicate of that WHERE clause;
    - at least one of those predicates is an equality.

    The correlated predicates are removed from the subquery's WHERE (replaced by TRUE).
    Equality predicates become the LEFT JOIN condition. The subquery is grouped by its
    equality keys. Its value, when it is neither an aggregate nor a key, is collected with
    ARRAY_AGG, or with MAX when the subquery is used directly as a projection.

    Args:
        select: the correlated subquery.
        parent_select: the SELECT that directly encloses `select`.
        external_columns: columns inside `select` that reference an outer scope.
        next_alias_name: zero-argument callable returning a fresh alias name.
    """
    where = select.args.get("where")

    # Only OR-free WHERE clauses are handled, and LIMIT / OFFSET subqueries are skipped.
    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # The key is the operand that does not contain the external column,
            # i.e. the subquery-side expression of the comparison.
            key = (
                predicate.right
                if any(node is column for node in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery is itself one of the parent's (possibly aliased) projections.
    is_subquery_projection = any(
        node is select.parent
        for node in map(lambda s: s.unalias(), parent_select.selects)
        if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all keys of equality predicates must also be grouped on
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped on are aggregated (ARRAY_AGG, or MAX for projections).
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # `alias` references the subquery's value in the derived table. For ANY / ALL,
    # `op_type` is the comparison class that wraps the subquery predicate.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)
    op_type = type(parent_predicate.parent) if parent_predicate else None

    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a NOT NULL check on the first joined key column.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        # `other <op> ALL(...)` is checked against every aggregated value.
        assert issubclass(op_type, exp.Binary)
        predicate = op_type(this=other, expression=exp.column("_x"))
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
        )
    elif isinstance(parent_predicate, exp.Any):
        assert issubclass(op_type, exp.Binary)
        if value.this in group_by:
            # The value is a grouping key, so it is a scalar: compare directly.
            predicate = op_type(this=other, expression=alias)
            parent_predicate = _replace(parent_predicate.parent, predicate)
        else:
            predicate = op_type(this=other, expression=exp.column("_x"))
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> {predicate})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # Scalar use: the subquery node itself is replaced by the joined value column.
        if is_subquery_projection and select.parent.alias:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need to take that into consideration
        # here by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)])

        select.parent.replace(alias)

    # Remove the correlated predicates from the subquery and re-express them against
    # the derived table's key columns.
    for key, column, predicate in keys:
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            if not isinstance(predicate, exp.EQ):
                parent_select.where(predicate, copy=False)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # The equality predicates, whose keys now reference the derived table, form the join condition.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _replace(expression, condition):
    """
    Replace `expression` in its tree with `condition` and return the result of the replace.

    `condition` may be a SQL string or an expression; it is normalized with `exp.condition`.
    Callers assign the return value and keep building on it, so they rely on it being the
    newly inserted node.
    """
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """
    Return the operand opposite the subquery in a subquery predicate.

    - `x IN (subquery)` -> `x`
    - ANY / ALL -> resolved through the enclosing comparison (the parent node)
    - a binary comparison -> the right side if the left side is a subquery, ANY, EXISTS
      or ALL; otherwise the left side

    Returns None for anything else, including None.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert scalar subqueries into cross joins.
    Convert correlated or vectorized subqueries into a group by so that the resulting
    left join is not many-to-many.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest; it is modified in place.
    Returns:
        sqlglot.Expression: the unnested expression (the same object that was passed in).
    """
    # One alias generator for the whole traversal, so every derived table gets a unique _u_N name.
    next_alias_name = name_sequence("_u_")

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        # Without an enclosing SELECT there is nothing to join the subquery into.
        if not parent:
            continue
        if scope.external_columns:
            # Correlated subquery: it references columns of an enclosing scope.
            decorrelate(select, parent, scope.external_columns, next_alias_name)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(select, parent, next_alias_name)

    return expression

Rewrite sqlglot AST to convert some predicates with subqueries into joins.

Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into a GROUP BY so that the resulting left join is not many-to-many.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
  • expression (sqlglot.Expression): expression to unnest; it is modified in place
Returns:
  • sqlglot.Expression: the unnested expression (the same object that was passed in)

def unnest(select, parent_select, next_alias_name):
    """
    Unnest a subquery of `parent_select` into a join, modifying the tree in place.

    Used for subqueries without external (correlated) column references.

    - Scalar subqueries become a CROSS JOIN and are referenced as a column, wrapped in
      MAX when the reference sits in an aggregated context.
    - `IN (subquery)` and `= ANY(subquery)` become a LEFT JOIN on the subquery grouped by
      its projected value. The original predicate is replaced by a NOT NULL check on the
      join key.

    The following are left untouched: subqueries with more than one projection; subqueries
    with no enclosing condition that belongs to `parent_select`; subqueries whose parent has
    no FROM clause; and IN / ANY subqueries containing LIMIT or OFFSET.

    Args:
        select: the subquery expression (a SELECT or a set operation) to unnest.
        parent_select: the SELECT that directly encloses `select`.
        next_alias_name: zero-argument callable returning a fresh alias name.
    """
    # Only single-column subqueries can be turned into a scalar column or a join key.
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    if isinstance(select, exp.SetOperation):
        # Wrap the set operation in a plain SELECT so it can be grouped / joined.
        # The wrapper is a freshly built node, so its `parent` is None.
        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))

    alias = next_alias_name()
    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        # A set operation re-wrapped above is detached, so there is no subquery node to
        # replace. Previously the MAX branch below then called `_replace(None, ...)` and
        # crashed; the non-MAX branch already bailed out for this case.
        if select.parent is None:
            return

        column = exp.column(select.selects[0].alias_or_name, alias)

        clause_parent_select = clause.parent_select if clause else None

        # Wrap the reference in MAX when it lands in an aggregated context. That is either
        # the parent's HAVING, or a position outside the parent's WHERE/JOIN/HAVING when
        # the parent has a GROUP BY or aggregate projections.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(
                    find_in_scope(projection, exp.AggFunc)
                    for projection in parent_select.selects
                )
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    # LIMIT / OFFSET subqueries are not handled by the grouped-join rewrite below.
    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # Only `<operand> = ANY(subquery)` is rewritten.
        # NOTE(review): find_ancestor returns the nearest EQ anywhere above the ANY, not
        # only its direct parent (e.g. `(x > ANY(...)) = TRUE` would match) — confirm
        # this is intended.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    join_key = exp.column(value.alias, alias)
    join_key_not_null = join_key.is_(exp.null()).not_()

    # After the LEFT JOIN, a non-NULL join key means the membership predicate held.
    if isinstance(clause, exp.Join):
        # Neutralise the predicate inside the JOIN condition and filter in WHERE instead.
        _replace(predicate, exp.true())
        parent_select.where(join_key_not_null, copy=False)
    else:
        _replace(predicate, join_key_not_null)

    group = select.args.get("group")

    # Group the subquery by its projected value so that the LEFT JOIN is not many-to-many.
    if group:
        if {value.this} != set(group.expressions):
            # The existing GROUP BY is not exactly the projected value: wrap the
            # subquery and group the wrapper by that value instead.
            select = (
                exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
                .from_(select.subquery("_q", copy=False), copy=False)
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    else:
        select = select.group_by(value.this, copy=False)

    parent_select.join(
        select,
        on=column.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Turn a correlated subquery into a LEFT JOIN on a grouped derived table, in place.

    Requirements (otherwise the tree is left untouched):
    - the subquery has a WHERE clause that contains no OR;
    - the subquery has no LIMIT / OFFSET;
    - every external column sits inside a binary predicate of that WHERE clause;
    - at least one of those predicates is an equality.

    The correlated predicates are removed from the subquery's WHERE (replaced by TRUE).
    Equality predicates become the LEFT JOIN condition. The subquery is grouped by its
    equality keys. Its value, when it is neither an aggregate nor a key, is collected with
    ARRAY_AGG, or with MAX when the subquery is used directly as a projection.

    Args:
        select: the correlated subquery.
        parent_select: the SELECT that directly encloses `select`.
        external_columns: columns inside `select` that reference an outer scope.
        next_alias_name: zero-argument callable returning a fresh alias name.
    """
    where = select.args.get("where")

    # Only OR-free WHERE clauses are handled, and LIMIT / OFFSET subqueries are skipped.
    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # The key is the operand that does not contain the external column,
            # i.e. the subquery-side expression of the comparison.
            key = (
                predicate.right
                if any(node is column for node in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery is itself one of the parent's (possibly aliased) projections.
    is_subquery_projection = any(
        node is select.parent
        for node in map(lambda s: s.unalias(), parent_select.selects)
        if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all keys of equality predicates must also be grouped on
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped on are aggregated (ARRAY_AGG, or MAX for projections).
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # `alias` references the subquery's value in the derived table. For ANY / ALL,
    # `op_type` is the comparison class that wraps the subquery predicate.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)
    op_type = type(parent_predicate.parent) if parent_predicate else None

    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a NOT NULL check on the first joined key column.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        # `other <op> ALL(...)` is checked against every aggregated value.
        assert issubclass(op_type, exp.Binary)
        predicate = op_type(this=other, expression=exp.column("_x"))
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
        )
    elif isinstance(parent_predicate, exp.Any):
        assert issubclass(op_type, exp.Binary)
        if value.this in group_by:
            # The value is a grouping key, so it is a scalar: compare directly.
            predicate = op_type(this=other, expression=alias)
            parent_predicate = _replace(parent_predicate.parent, predicate)
        else:
            predicate = op_type(this=other, expression=exp.column("_x"))
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> {predicate})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # Scalar use: the subquery node itself is replaced by the joined value column.
        if is_subquery_projection and select.parent.alias:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need to take that into consideration
        # here by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)])

        select.parent.replace(alias)

    # Remove the correlated predicates from the subquery and re-express them against
    # the derived table's key columns.
    for key, column, predicate in keys:
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            if not isinstance(predicate, exp.EQ):
                parent_select.where(predicate, copy=False)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # The equality predicates, whose keys now reference the derived table, form the join condition.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )