Edit on GitHub

sqlglot.optimizer.unnest_subqueries

  1from sqlglot import exp
  2from sqlglot.helper import name_sequence
  3from sqlglot.optimizer.scope import ScopeType, traverse_scope
  4
  5
  6def unnest_subqueries(expression):
  7    """
  8    Rewrite sqlglot AST to convert some predicates with subqueries into joins.
  9
 10    Convert scalar subqueries into cross joins.
 11    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.
 12
 13    Example:
 14        >>> import sqlglot
 15        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
 16        >>> unnest_subqueries(expression).sql()
 17        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
 18
 19    Args:
 20        expression (sqlglot.Expression): expression to unnest
 21    Returns:
 22        sqlglot.Expression: unnested expression
 23    """
 24    next_alias_name = name_sequence("_u_")
 25
 26    for scope in traverse_scope(expression):
 27        select = scope.expression
 28        parent = select.parent_select
 29        if not parent:
 30            continue
 31        if scope.external_columns:
 32            decorrelate(select, parent, scope.external_columns, next_alias_name)
 33        elif scope.scope_type == ScopeType.SUBQUERY:
 34            unnest(select, parent, next_alias_name)
 35
 36    return expression
 37
 38
 39def unnest(select, parent_select, next_alias_name):
 40    if len(select.selects) > 1:
 41        return
 42
 43    predicate = select.find_ancestor(exp.Condition)
 44    alias = next_alias_name()
 45
 46    if (
 47        not predicate
 48        or parent_select is not predicate.parent_select
 49        or not parent_select.args.get("from")
 50    ):
 51        return
 52
 53    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
 54
 55    # This subquery returns a scalar and can just be converted to a cross join
 56    if not isinstance(predicate, (exp.In, exp.Any)):
 57        column = exp.column(select.selects[0].alias_or_name, alias)
 58
 59        clause_parent_select = clause.parent_select if clause else None
 60
 61        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
 62            (not clause or clause_parent_select is not parent_select)
 63            and (
 64                parent_select.args.get("group")
 65                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
 66            )
 67        ):
 68            column = exp.Max(this=column)
 69        elif not isinstance(select.parent, exp.Subquery):
 70            return
 71
 72        _replace(select.parent, column)
 73        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
 74        return
 75
 76    if select.find(exp.Limit, exp.Offset):
 77        return
 78
 79    if isinstance(predicate, exp.Any):
 80        predicate = predicate.find_ancestor(exp.EQ)
 81
 82        if not predicate or parent_select is not predicate.parent_select:
 83            return
 84
 85    column = _other_operand(predicate)
 86    value = select.selects[0]
 87
 88    join_key = exp.column(value.alias, alias)
 89    join_key_not_null = join_key.is_(exp.null()).not_()
 90
 91    if isinstance(clause, exp.Join):
 92        _replace(predicate, exp.true())
 93        parent_select.where(join_key_not_null, copy=False)
 94    else:
 95        _replace(predicate, join_key_not_null)
 96
 97    group = select.args.get("group")
 98
 99    if group:
100        if {value.this} != set(group.expressions):
101            select = (
102                exp.select(exp.column(value.alias, "_q"))
103                .from_(select.subquery("_q", copy=False), copy=False)
104                .group_by(exp.column(value.alias, "_q"), copy=False)
105            )
106    else:
107        select = select.group_by(value.this, copy=False)
108
109    parent_select.join(
110        select,
111        on=column.eq(join_key),
112        join_type="LEFT",
113        join_alias=alias,
114        copy=False,
115    )
116
117
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a LEFT JOIN of a grouped subquery onto `parent_select`.

    Each predicate of the subquery's WHERE clause that uses an external column is
    replaced by TRUE inside the subquery. The equality predicates among them are reused
    as the ON condition of the join. The subquery-side operand of each such predicate
    (the "key") is added to the subquery's projections, either as a grouping key or
    wrapped in an aggregate, so that the outer query can refer to it. The place where
    the subquery was used (EXISTS, ALL, ANY, IN, or any other usage) is rewritten to
    read from the joined subquery.

    The function returns without rewriting when:
      - the subquery has no WHERE clause, its WHERE contains an OR, or the subquery
        contains LIMIT / OFFSET
      - an external column is not located inside that WHERE clause
      - an external column is not part of a binary predicate of that WHERE clause
      - none of the collected predicates is an equality

    Args:
        select: query expression of the correlated subquery; modified in place.
        parent_select: the query containing `select`; modified in place.
        external_columns: columns used inside `select` that come from outside of it
            (the caller passes `scope.external_columns`).
        next_alias_name: callable returning a new alias name on each call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # NOTE: drawn before the checks in the loop below, so the name is consumed from the
    # sequence even if one of those checks makes the function return early.
    table_alias = next_alias_name()
    # list of (key, external column, predicate) triples
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is the operand of the predicate that does NOT contain the external column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # at least one equality is required
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery node is itself one of the parent's projections
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    # key expression -> alias under which it is projected by the subquery
    key_aliases = {}
    # key expressions the subquery will be grouped by
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    # the predicate of the outer query in which the subquery is used (may be None)
    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        # append=False: the aggregated value replaces the existing projections
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # keys that are not grouped on are projected through the aggregate instead
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # from here on `alias` is the outer-query reference to the subquery's value column
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        # EXISTS: check that the first key column of the joined subquery is not NULL
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        # ALL: the node that is replaced is the parent of the ALL node
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            # the value is a grouping key, so a plain equality is used;
            # here the node that is replaced is the parent of the ANY node
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            # NOTE(review): unlike the ALL branch, this replaces the ANY node itself
            # rather than its parent - confirm this is intended
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # any other usage: the subquery node itself is replaced by the column reference
        if is_subquery_projection:
            # keep the alias the subquery had as a projection
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                # COUNT -> 0, any other aggregate -> NULL, everything else unchanged
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # take the predicate out of the subquery's WHERE; `predicate` stays usable below
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            # inside the predicate, point the key at the joined subquery's column
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # the key becomes the lambda variable _x before the predicate is
            # rendered into the f-string below
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # join the grouped subquery; the equality predicates collected above form the ON condition
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )
273
274
def _replace(expression, condition):
    """
    Replace `expression` in its tree with `condition`.

    `condition` is first passed through `exp.condition`; the result of
    `expression.replace` is returned.
    """
    replacement = exp.condition(condition)
    return expression.replace(replacement)
277
278
def _other_operand(expression):
    """
    Return the operand that a subquery predicate is compared against.

    - IN: the expression on the left of IN (`this`).
    - ANY / ALL: resolved through the parent node.
    - Binary: the right operand if the left one is a Subquery / ANY / EXISTS / ALL,
      otherwise the left operand.
    - anything else: None.
    """
    node = expression

    # ANY / ALL defer to the node that contains them.
    while isinstance(node, (exp.Any, exp.All)):
        node = node.parent

    if isinstance(node, exp.In):
        return node.this

    if not isinstance(node, exp.Binary):
        return None

    left = node.left
    if isinstance(left, (exp.Subquery, exp.Any, exp.Exists, exp.All)):
        return node.right
    return left
def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST so that some predicates with subqueries become joins.

    Scalar subqueries are converted into cross joins.
    Correlated or vectorized subqueries are converted into a grouped subquery that is
    left joined, so that the join is not a many to many join.

    The same expression object that was passed in is returned.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # A single "_u_" name sequence is shared by every rewrite done in this call.
    next_alias_name = name_sequence("_u_")

    for scope in traverse_scope(expression):
        inner = scope.expression
        outer = inner.parent_select

        # Scopes without an enclosing select have nothing to be joined onto.
        if outer:
            correlated = scope.external_columns
            if correlated:
                # The subquery references columns from outside of itself.
                decorrelate(inner, outer, correlated, next_alias_name)
            elif scope.scope_type == ScopeType.SUBQUERY:
                unnest(inner, outer, next_alias_name)

    return expression

Rewrite sqlglot AST to convert some predicates with subqueries into joins.

Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into grouped subqueries that are left-joined, so that the join is not many-to-many.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
  • expression (sqlglot.Expression): expression to unnest
Returns:
  • sqlglot.Expression: unnested expression (the same object that was passed in)

def unnest(select, parent_select, next_alias_name):
    """
    Replace a subquery without external columns by a join on `parent_select`.

    A subquery that is not the operand of IN / ANY is treated as a scalar: it is cross
    joined and the place where it was used is replaced by a column of the join.
    An IN / ANY subquery is grouped on its single projection and left joined; the
    original predicate is turned into a "join key IS NOT NULL" check.

    Nothing is rewritten when the subquery has more than one projection, when it is not
    inside a condition that belongs to `parent_select`, when `parent_select` has no FROM
    clause, or (IN / ANY only) when the subquery contains LIMIT or OFFSET.

    Args:
        select: query expression of the subquery; may be modified in place.
        parent_select: the query containing `select`; modified in place.
        next_alias_name: callable returning a new alias name on each call.
    """
    projections = select.selects
    if len(projections) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # The alias is drawn before the checks below, so a name is consumed from the
    # sequence even when this function returns without rewriting anything.
    alias = next_alias_name()

    if not predicate:
        return
    if predicate.parent_select is not parent_select:
        return
    if not parent_select.args.get("from"):
        return

    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        scalar = exp.column(projections[0].alias_or_name, alias)
        clause_select = clause.parent_select if clause else None

        if isinstance(clause, exp.Having) and clause_select is parent_select:
            # used in the HAVING clause of the parent itself
            wrap_in_max = True
        elif clause and clause_select is parent_select:
            # used in a WHERE / JOIN clause of the parent itself
            wrap_in_max = False
        else:
            # used outside of the parent's own HAVING / WHERE / JOIN clauses:
            # wrap only if the parent groups or has an aggregate in a projection
            wrap_in_max = bool(parent_select.args.get("group")) or any(
                projection.find(exp.AggFunc) for projection in parent_select.selects
            )

        if wrap_in_max:
            scalar = exp.Max(this=scalar)
        elif not isinstance(select.parent, exp.Subquery):
            return

        _replace(select.parent, scalar)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # For ANY, the predicate that gets rewritten is the enclosing equality.
        predicate = predicate.find_ancestor(exp.EQ)
        if not predicate or parent_select is not predicate.parent_select:
            return

    # Must be read before the predicate is replaced below.
    operand = _other_operand(predicate)
    value = projections[0]

    join_key = exp.column(value.alias, alias)
    not_null = join_key.is_(exp.null()).not_()

    if isinstance(clause, exp.Join):
        # Inside a JOIN clause the predicate becomes TRUE and the null check
        # is added to the parent's WHERE instead.
        _replace(predicate, exp.true())
        parent_select.where(not_null, copy=False)
    else:
        _replace(predicate, not_null)

    group = select.args.get("group")

    if not group:
        select = select.group_by(value.this, copy=False)
    elif {value.this} != set(group.expressions):
        # Already grouped, but not exactly on the projected value: wrap the
        # subquery and group the wrapper on the projected column.
        select = (
            exp.select(exp.column(value.alias, "_q"))
            .from_(select.subquery("_q", copy=False), copy=False)
            .group_by(exp.column(value.alias, "_q"), copy=False)
        )

    parent_select.join(
        select,
        on=operand.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a LEFT JOIN of a grouped subquery onto `parent_select`.

    Each predicate of the subquery's WHERE clause that uses an external column is
    replaced by TRUE inside the subquery. The equality predicates among them are reused
    as the ON condition of the join. The subquery-side operand of each such predicate
    (the "key") is added to the subquery's projections, either as a grouping key or
    wrapped in an aggregate, so that the outer query can refer to it. The place where
    the subquery was used (EXISTS, ALL, ANY, IN, or any other usage) is rewritten to
    read from the joined subquery.

    The function returns without rewriting when:
      - the subquery has no WHERE clause, its WHERE contains an OR, or the subquery
        contains LIMIT / OFFSET
      - an external column is not located inside that WHERE clause
      - an external column is not part of a binary predicate of that WHERE clause
      - none of the collected predicates is an equality

    Args:
        select: query expression of the correlated subquery; modified in place.
        parent_select: the query containing `select`; modified in place.
        external_columns: columns used inside `select` that come from outside of it
            (the caller passes `scope.external_columns`).
        next_alias_name: callable returning a new alias name on each call.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # NOTE: drawn before the checks in the loop below, so the name is consumed from the
    # sequence even if one of those checks makes the function return early.
    table_alias = next_alias_name()
    # list of (key, external column, predicate) triples
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # the key is the operand of the predicate that does NOT contain the external column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    # at least one equality is required
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery node is itself one of the parent's projections
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    # key expression -> alias under which it is projected by the subquery
    key_aliases = {}
    # key expressions the subquery will be grouped by
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    # the predicate of the outer query in which the subquery is used (may be None)
    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        # append=False: the aggregated value replaces the existing projections
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # keys that are not grouped on are projected through the aggregate instead
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # from here on `alias` is the outer-query reference to the subquery's value column
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        # EXISTS: check that the first key column of the joined subquery is not NULL
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        # ALL: the node that is replaced is the parent of the ALL node
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            # the value is a grouping key, so a plain equality is used;
            # here the node that is replaced is the parent of the ANY node
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            # NOTE(review): unlike the ALL branch, this replaces the ANY node itself
            # rather than its parent - confirm this is intended
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # any other usage: the subquery node itself is replaced by the column reference
        if is_subquery_projection:
            # keep the alias the subquery had as a projection
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                # COUNT -> 0, any other aggregate -> NULL, everything else unchanged
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # take the predicate out of the subquery's WHERE; `predicate` stays usable below
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            # inside the predicate, point the key at the joined subquery's column
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # the key becomes the lambda variable _x before the predicate is
            # rendered into the f-string below
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # join the grouped subquery; the equality predicates collected above form the ON condition
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )