Edit on GitHub

sqlglot.optimizer.merge_subqueries

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from collections import defaultdict
  6
  7from sqlglot import expressions as exp
  8from sqlglot.helper import find_new_name
  9from sqlglot.optimizer.scope import Scope, traverse_scope
 10
 11if t.TYPE_CHECKING:
 12    from sqlglot._typing import E
 13
 14    FromOrJoin = t.Union[exp.From, exp.Join]
 15
 16
 17def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
 18    """
 19    Rewrite sqlglot AST to merge derived tables into the outer query.
 20
 21    This also merges CTEs if they are selected from only once.
 22
 23    Example:
 24        >>> import sqlglot
 25        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
 26        >>> merge_subqueries(expression).sql()
 27        'SELECT x.a FROM x CROSS JOIN y'
 28
 29    If `leave_tables_isolated` is True, this will not merge inner queries into outer
 30    queries if it would result in multiple table selects in a single query:
 31        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
 32        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
 33        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'
 34
 35    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
 36
 37    Args:
 38        expression (sqlglot.Expression): expression to optimize
 39        leave_tables_isolated (bool):
 40    Returns:
 41        sqlglot.Expression: optimized expression
 42    """
 43    expression = merge_ctes(expression, leave_tables_isolated)
 44    expression = merge_derived_tables(expression, leave_tables_isolated)
 45    return expression
 46
 47
 48# If a derived table has these Select args, it can't be merged
 49UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
 50    "expressions",
 51    "from",
 52    "joins",
 53    "where",
 54    "order",
 55    "hint",
 56}
 57
 58
 59# Projections in the outer query that are instances of these types can be replaced
 60# without getting wrapped in parentheses, because the precedence won't be altered.
 61SAFE_TO_REPLACE_UNWRAPPED = (
 62    exp.Column,
 63    exp.EQ,
 64    exp.Func,
 65    exp.NEQ,
 66    exp.Paren,
 67)
 68
 69
 70def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
 71    scopes = traverse_scope(expression)
 72
 73    # All places where we select from CTEs.
 74    # We key on the CTE scope so we can detect CTES that are selected from multiple times.
 75    cte_selections = defaultdict(list)
 76    for outer_scope in scopes:
 77        for table, inner_scope in outer_scope.selected_sources.values():
 78            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
 79                cte_selections[id(inner_scope)].append(
 80                    (
 81                        outer_scope,
 82                        inner_scope,
 83                        table,
 84                    )
 85                )
 86
 87    singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
 88    for outer_scope, inner_scope, table in singular_cte_selections:
 89        from_or_join = table.find_ancestor(exp.From, exp.Join)
 90        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 91            alias = table.alias_or_name
 92            _rename_inner_sources(outer_scope, inner_scope, alias)
 93            _merge_from(outer_scope, inner_scope, table, alias)
 94            _merge_expressions(outer_scope, inner_scope, alias)
 95            _merge_joins(outer_scope, inner_scope, from_or_join)
 96            _merge_where(outer_scope, inner_scope, from_or_join)
 97            _merge_order(outer_scope, inner_scope)
 98            _merge_hints(outer_scope, inner_scope)
 99            _pop_cte(inner_scope)
100            outer_scope.clear_cache()
101    return expression
102
103
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge every mergeable derived table into the query that selects from it.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to the mergeability check
    Returns:
        sqlglot.Expression: the same expression, after merging
    """
    for outer_scope in traverse_scope(expression):
        for subquery in outer_scope.derived_tables:
            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
            alias = subquery.alias_or_name
            inner_scope = outer_scope.sources[alias]

            if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
                continue

            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, subquery, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_order(outer_scope, inner_scope)
            _merge_hints(outer_scope, inner_scope)
            outer_scope.clear_cache()

    return expression
121
122
def _mergeable(
    outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
) -> bool:
    """
    Return True if the inner query can be merged into the outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope): scope of the query that selects
            from the inner query
        inner_scope (sqlglot.optimizer.scope.Scope): scope of the derived table or CTE
            that would be merged
        leave_tables_isolated (bool): if True, never merge into an outer query that
            selects from more than one source
        from_or_join (exp.From|exp.Join): clause of the outer query that contains
            the reference to the inner query
    """
    inner_select = inner_scope.expression.unnest()

    def _is_a_window_expression_in_unmergable_operation():
        """
        True if the outer query references a window projection of the inner query from
        inside a WHERE, GROUP BY, ORDER BY, JOIN, HAVING or aggregate function.
        """
        window_aliases = {s.alias_or_name for s in inner_select.selects if s.find(exp.Window)}
        inner_select_name = from_or_join.alias_or_name
        unmergable_window_columns = [
            column
            for column in outer_scope.columns
            if column.find_ancestor(
                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
            )
        ]
        window_expressions_in_unmergable = [
            column
            for column in unmergable_window_columns
            if column.table == inner_select_name and column.name in window_aliases
        ]
        return any(window_expressions_in_unmergable)

    def _outer_select_joins_on_inner_select_join():
        """
        All columns from the inner select in the ON clause must be from the first FROM table.

        That is, this can be merged:
            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^           ^
        But this can't:
            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^                  ^
        """
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.alias_or_name

        on = from_or_join.args.get("on")
        if not on:
            return False
        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
        inner_from = inner_scope.expression.args.get("from")
        if not inner_from:
            return False
        inner_from_table = inner_from.alias_or_name
        inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects}

        for selection in selections:
            projection = inner_projections.get(selection)
            if projection is None:
                # The ON clause references a name the inner query does not project
                # (e.g. an unexpanded star), so it can't be traced back to a table.
                # Refuse to merge instead of failing with a KeyError.
                return True
            if any(col.table != inner_from_table for col in projection.find_all(exp.Column)):
                return True
        return False

    def _is_recursive():
        # Recursive CTEs look like this:
        #     WITH RECURSIVE cte AS (
        #       SELECT * FROM x  <-- inner scope
        #       UNION ALL
        #       SELECT * FROM cte  <-- outer scope
        #     )
        cte = inner_scope.expression.parent
        node = outer_scope.expression.parent

        # The outer query lives inside the CTE it selects from if the CTE node
        # is one of its ancestors.
        while node:
            if node is cte:
                return True
            node = node.parent
        return False

    return (
        isinstance(outer_scope.expression, exp.Select)
        and not outer_scope.expression.is_star
        and isinstance(inner_select, exp.Select)
        # The inner query may only use the Select args that merging knows how to carry over
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from") is not None
        and not outer_scope.pivots
        and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        # An inner WHERE is not merged when the inner query is joined with a FULL/LEFT/RIGHT join
        and not (
            isinstance(from_or_join, exp.Join)
            and inner_select.args.get("where")
            and from_or_join.side in ("FULL", "LEFT", "RIGHT")
        )
        # ... nor when it is the FROM source and the outer query has a FULL/RIGHT join
        and not (
            isinstance(from_or_join, exp.From)
            and inner_select.args.get("where")
            and any(
                j.side in ("FULL", "RIGHT") for j in outer_scope.expression.args.get("joins", [])
            )
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
        and not _is_recursive()
        and not (inner_select.args.get("order") and outer_scope.is_union)
    )
221
222
def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
    """
    Give new names to inner-query sources whose names are also used by the outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str): name under which the outer query references the inner query;
            it is never treated as a conflict
    """
    outer_names = set(outer_scope.selected_sources)
    inner_names = set(inner_scope.selected_sources)

    clashing = outer_names.intersection(inner_names)
    clashing.discard(alias)

    all_names = outer_names.union(inner_names)

    for old_name in clashing:
        new_name = find_new_name(all_names, old_name)

        source, _ = inner_scope.selected_sources[old_name]
        new_alias = exp.to_identifier(new_name)

        if isinstance(source, exp.Table):
            if source.alias:
                source.set("alias", new_alias)
            else:
                source.replace(exp.alias_(source, new_alias))
        elif isinstance(source.parent, exp.Subquery):
            source.parent.set("alias", exp.TableAlias(this=new_alias))

        # Re-qualify the inner columns that pointed at the old source name
        for column in inner_scope.source_columns(old_name):
            column.set("table", exp.to_identifier(new_name))

        inner_scope.rename_source(old_name, new_name)
251
252
def _merge_from(
    outer_scope: Scope,
    inner_scope: Scope,
    node_to_replace: t.Union[exp.Subquery, exp.Table],
    alias: str,
) -> None:
    """
    Replace the outer reference to the inner query with the inner query's FROM source.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        node_to_replace (exp.Subquery|exp.Table): node in the outer query that
            references the inner query
        alias (str): name under which the outer scope knows the inner query
    """
    inner_source = inner_scope.expression.args["from"].this
    # Keep whatever joins were attached to the node that is being replaced
    inner_source.set("joins", node_to_replace.args.get("joins"))
    node_to_replace.replace(inner_source)

    # Point join hints that named the replaced node at the new source instead
    replaced_name = node_to_replace.alias_or_name
    for join_hint in outer_scope.join_hints:
        for hinted_table in join_hint.find_all(exp.Table):
            if hinted_table.alias_or_name == replaced_name:
                hinted_table.set("this", exp.to_identifier(inner_source.alias_or_name))

    outer_scope.remove_source(alias)
    source_name = inner_source.alias_or_name
    outer_scope.add_source(source_name, inner_scope.sources[source_name])
274
275
def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
    """
    Move the JOIN clauses of the inner query into the outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join): clause of the outer query that contained
            the inner query; it determines where the inner joins are inserted
    """
    inner_joins = list(inner_scope.expression.args.get("joins") or [])

    # Each joined source becomes a source of the outer scope
    for join in inner_joins:
        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])

    if not inner_joins:
        return

    outer_joins = outer_scope.expression.args.get("joins", [])

    # Maintain the join order: inner joins go right after the clause they came from
    if isinstance(from_or_join, exp.From):
        insert_at = 0
    else:
        insert_at = outer_joins.index(from_or_join) + 1
    outer_joins[insert_at:insert_at] = inner_joins

    outer_scope.expression.set("joins", outer_joins)
299
300
def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
    """
    Merge projections of inner query into outer query.

    Every outer column qualified with `alias` is replaced by a copy of the inner
    projection that has the same name.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str): name under which the outer query references the inner query
    """
    # Collect all columns that reference the alias of the inner query
    outer_columns = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            outer_columns[column.name].append(column)

    # Replace columns with the projection expression in the inner query
    for projection in inner_scope.expression.expressions:
        projection_name = projection.alias_or_name
        if not projection_name:
            continue
        columns_to_replace = outer_columns.get(projection_name, [])

        expression = projection.unalias()
        must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED)

        for column in columns_to_replace:
            # Build the replacement per column. Wrapping the shared `expression` in
            # place would stack another pair of parentheses for every later column
            # and would re-parent the inner query's own projection.
            replacement = expression.copy()

            # Ensures we don't alter the intended operator precedence if there's additional
            # context surrounding the outer expression (i.e. it's not a simple projection).
            if must_wrap_expression and isinstance(column.parent, (exp.Unary, exp.Binary)):
                replacement = exp.paren(replacement, copy=False)

            column.replace(replacement)
333
334
def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
    """
    Move the WHERE condition of the inner query into the outer query.

    When the inner query was joined, the condition goes into that join's ON clause
    if all the tables it references have been joined by then; otherwise it is
    added to the outer WHERE clause.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    inner_where = inner_scope.expression.args.get("where")
    if not inner_where or not inner_where.this:
        return

    outer_select = outer_scope.expression

    if not isinstance(from_or_join, exp.Join):
        outer_select.where(inner_where.this, copy=False)
        return

    # Names of the FROM source and of every join up to and including `from_or_join`
    outer_from = outer_select.args.get("from")
    joined_so_far = {outer_from.alias_or_name} if outer_from else set()

    target_name = from_or_join.alias_or_name
    for join in outer_select.args["joins"]:
        join_name = join.alias_or_name
        joined_so_far.add(join_name)
        if join_name == target_name:
            break

    if exp.column_table_names(inner_where.this) <= joined_so_far:
        # Merge predicates from an outer join to the ON clause
        # if it only has columns that are already joined
        from_or_join.on(inner_where.this, copy=False)
        from_or_join.set("on", from_or_join.args.get("on"))
    else:
        outer_select.where(inner_where.this, copy=False)
368
369
def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
    """
    Carry the ORDER clause of the inner query over to the outer query.

    Nothing is done if the outer query has a GROUP BY, DISTINCT, HAVING or ORDER BY,
    does not select from exactly one source, or has an aggregate function in one
    of its projections.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    outer_select = outer_scope.expression

    for arg in ("group", "distinct", "having", "order"):
        if outer_select.args.get(arg):
            return

    if len(outer_scope.selected_sources) != 1:
        return

    for projection in outer_select.expressions:
        if projection.find(exp.AggFunc):
            return

    outer_select.set("order", inner_scope.expression.args.get("order"))
388
389
def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
    """
    Merge the hint of the inner query into the outer query.

    If the outer query already has a hint, the inner hint's expressions are
    appended to it; otherwise the inner hint becomes the outer hint.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    inner_hint = inner_scope.expression.args.get("hint")
    if not inner_hint:
        return

    outer_hint = outer_scope.expression.args.get("hint")
    if not outer_hint:
        outer_scope.expression.set("hint", inner_hint)
        return

    for hint_expression in inner_hint.expressions:
        outer_hint.append("expressions", hint_expression)
400
401
def _pop_cte(inner_scope: Scope) -> None:
    """
    Remove the CTE of `inner_scope` from the AST.

    The whole WITH clause is removed when this is its only CTE.

    Args:
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    cte = inner_scope.expression.parent
    with_ = cte.parent
    is_only_cte = len(with_.expressions) == 1
    node_to_remove = with_ if is_only_cte else cte
    node_to_remove.pop()
def merge_subqueries(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Flatten derived tables and single-use CTEs into the queries that select from them.

    CTEs are merged first (only those that are selected from exactly once),
    then derived tables.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x CROSS JOIN y'

    If `leave_tables_isolated` is True, this will not merge inner queries into outer
    queries if it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, do not merge an inner query into an
            outer query that selects from more than one source
    Returns:
        sqlglot.Expression: optimized expression
    """
    without_single_use_ctes = merge_ctes(expression, leave_tables_isolated)
    return merge_derived_tables(without_single_use_ctes, leave_tables_isolated)

Rewrite sqlglot AST to merge derived tables into the outer query.

This also merges CTEs if they are selected from only once.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression).sql()
'SELECT x.a FROM x CROSS JOIN y'

If leave_tables_isolated is True, this will not merge inner queries into outer queries if it would result in multiple table selects in a single query:

>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • leave_tables_isolated (bool): if True, do not merge an inner query into an outer query that selects from more than one source
Returns:

sqlglot.Expression: optimized expression

UNMERGABLE_ARGS = {'qualify', 'sample', 'cluster', 'into', 'settings', 'offset', 'having', 'distinct', 'prewhere', 'kind', 'distribute', 'sort', 'format', 'options', 'limit', 'pivots', 'match', 'group', 'laterals', 'windows', 'operation_modifiers', 'connect', 'locks', 'with'}
SAFE_TO_REPLACE_UNWRAPPED = (<class 'sqlglot.expressions.Column'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.Func'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.Paren'>)
def merge_ctes(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
 71def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
 72    scopes = traverse_scope(expression)
 73
 74    # All places where we select from CTEs.
 75    # We key on the CTE scope so we can detect CTES that are selected from multiple times.
 76    cte_selections = defaultdict(list)
 77    for outer_scope in scopes:
 78        for table, inner_scope in outer_scope.selected_sources.values():
 79            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
 80                cte_selections[id(inner_scope)].append(
 81                    (
 82                        outer_scope,
 83                        inner_scope,
 84                        table,
 85                    )
 86                )
 87
 88    singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
 89    for outer_scope, inner_scope, table in singular_cte_selections:
 90        from_or_join = table.find_ancestor(exp.From, exp.Join)
 91        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 92            alias = table.alias_or_name
 93            _rename_inner_sources(outer_scope, inner_scope, alias)
 94            _merge_from(outer_scope, inner_scope, table, alias)
 95            _merge_expressions(outer_scope, inner_scope, alias)
 96            _merge_joins(outer_scope, inner_scope, from_or_join)
 97            _merge_where(outer_scope, inner_scope, from_or_join)
 98            _merge_order(outer_scope, inner_scope)
 99            _merge_hints(outer_scope, inner_scope)
100            _pop_cte(inner_scope)
101            outer_scope.clear_cache()
102    return expression
def merge_derived_tables(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge every mergeable derived table into the query that selects from it.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to the mergeability check
    Returns:
        sqlglot.Expression: the same expression, after merging
    """
    for outer_scope in traverse_scope(expression):
        for subquery in outer_scope.derived_tables:
            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
            alias = subquery.alias_or_name
            inner_scope = outer_scope.sources[alias]

            if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
                continue

            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, subquery, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_order(outer_scope, inner_scope)
            _merge_hints(outer_scope, inner_scope)
            outer_scope.clear_cache()

    return expression