Edit on GitHub

sqlglot.optimizer.merge_subqueries

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from collections import defaultdict
  6
  7from sqlglot import expressions as exp
  8from sqlglot.helper import find_new_name, seq_get
  9from sqlglot.optimizer.scope import Scope, traverse_scope
 10
 11if t.TYPE_CHECKING:
 12    from sqlglot._typing import E
 13
 14    FromOrJoin = t.Union[exp.From, exp.Join]
 15
 16
 17def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
 18    """
 19    Rewrite sqlglot AST to merge derived tables into the outer query.
 20
 21    This also merges CTEs if they are selected from only once.
 22
 23    Example:
 24        >>> import sqlglot
 25        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
 26        >>> merge_subqueries(expression).sql()
 27        'SELECT x.a FROM x CROSS JOIN y'
 28
 29    If `leave_tables_isolated` is True, this will not merge inner queries into outer
 30    queries if it would result in multiple table selects in a single query:
 31        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
 32        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
 33        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'
 34
 35    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
 36
 37    Args:
 38        expression (sqlglot.Expression): expression to optimize
 39        leave_tables_isolated (bool):
 40    Returns:
 41        sqlglot.Expression: optimized expression
 42    """
 43    expression = merge_ctes(expression, leave_tables_isolated)
 44    expression = merge_derived_tables(expression, leave_tables_isolated)
 45    return expression
 46
 47
 48# If a derived table has these Select args, it can't be merged
 49UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
 50    "expressions",
 51    "from",
 52    "joins",
 53    "where",
 54    "order",
 55    "hint",
 56}
 57
 58
 59# Projections in the outer query that are instances of these types can be replaced
 60# without getting wrapped in parentheses, because the precedence won't be altered.
 61SAFE_TO_REPLACE_UNWRAPPED = (
 62    exp.Column,
 63    exp.EQ,
 64    exp.Func,
 65    exp.NEQ,
 66    exp.Paren,
 67)
 68
 69
 70def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
 71    scopes = traverse_scope(expression)
 72
 73    # All places where we select from CTEs.
 74    # We key on the CTE scope so we can detect CTES that are selected from multiple times.
 75    cte_selections = defaultdict(list)
 76    for outer_scope in scopes:
 77        for table, inner_scope in outer_scope.selected_sources.values():
 78            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
 79                cte_selections[id(inner_scope)].append(
 80                    (
 81                        outer_scope,
 82                        inner_scope,
 83                        table,
 84                    )
 85                )
 86
 87    singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
 88    for outer_scope, inner_scope, table in singular_cte_selections:
 89        from_or_join = table.find_ancestor(exp.From, exp.Join)
 90        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 91            alias = table.alias_or_name
 92            _rename_inner_sources(outer_scope, inner_scope, alias)
 93            _merge_from(outer_scope, inner_scope, table, alias)
 94            _merge_expressions(outer_scope, inner_scope, alias)
 95            _merge_order(outer_scope, inner_scope)
 96            _merge_joins(outer_scope, inner_scope, from_or_join)
 97            _merge_where(outer_scope, inner_scope, from_or_join)
 98            _merge_hints(outer_scope, inner_scope)
 99            _pop_cte(inner_scope)
100            outer_scope.clear_cache()
101    return expression
102
103
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge every mergeable derived table into the query that selects from it.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the same expression, modified in place
    """
    for outer_scope in traverse_scope(expression):
        for subquery in outer_scope.derived_tables:
            alias = subquery.alias_or_name
            inner_scope = outer_scope.sources[alias]
            from_or_join = subquery.find_ancestor(exp.From, exp.Join)

            if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
                continue

            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, subquery, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_order(outer_scope, inner_scope)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_hints(outer_scope, inner_scope)
            outer_scope.clear_cache()

    return expression
121
122
def _mergeable(
    outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
) -> bool:
    """
    Return True if the query of `inner_scope` can be merged into the query of `outer_scope`.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope): scope of the query selecting from the inner query
        inner_scope (sqlglot.optimizer.scope.Scope): scope of the derived table or CTE to merge
        leave_tables_isolated (bool): if True, refuse the merge when the outer query
            selects from more than one source
        from_or_join (exp.From|exp.Join): clause of the outer query that references the inner query
    Returns:
        bool: True only if every check in the final conjunction passes
    """
    inner_select = inner_scope.expression.unnest()

    def _is_a_window_expression_in_unmergable_operation():
        # Names of the inner projections that contain a window expression.
        window_aliases = {s.alias_or_name for s in inner_select.selects if s.find(exp.Window)}
        inner_select_name = from_or_join.alias_or_name
        # Outer columns located under one of these clauses / aggregate functions.
        unmergable_window_columns = [
            column
            for column in outer_scope.columns
            if column.find_ancestor(
                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
            )
        ]
        window_expressions_in_unmergable = [
            column
            for column in unmergable_window_columns
            if column.table == inner_select_name and column.name in window_aliases
        ]
        return any(window_expressions_in_unmergable)

    def _outer_select_joins_on_inner_select_join():
        """
        All columns from the inner select in the ON clause must be from the first FROM table.

        That is, this can be merged:
            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^           ^
        But this can't:
            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^                  ^
        """
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.alias_or_name

        on = from_or_join.args.get("on")
        if not on:
            return False
        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
        inner_from = inner_scope.expression.args.get("from")
        if not inner_from:
            return False
        inner_from_table = inner_from.alias_or_name
        inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects}

        for selection in selections:
            projection = inner_projections.get(selection)
            if projection is None:
                # The ON clause references a column that is not a named projection of the
                # inner query (e.g. the inner query projects `*`). Its origin cannot be
                # checked, so report the join as blocking instead of raising KeyError.
                return True
            if any(col.table != inner_from_table for col in projection.find_all(exp.Column)):
                return True
        return False

    def _is_recursive():
        # Recursive CTEs look like this:
        #     WITH RECURSIVE cte AS (
        #       SELECT * FROM x  <-- inner scope
        #       UNION ALL
        #       SELECT * FROM cte  <-- outer scope
        #     )
        cte = inner_scope.expression.parent
        node = outer_scope.expression.parent

        # Walk up from the outer query; reaching the CTE node means the outer query
        # lives inside the CTE it selects from.
        while node:
            if node is cte:
                return True
            node = node.parent
        return False

    return (
        # The outer query must be a SELECT without a star projection
        isinstance(outer_scope.expression, exp.Select)
        and not outer_scope.expression.is_star
        # The inner query must be a SELECT with a FROM and none of the unmergeable args set
        and isinstance(inner_select, exp.Select)
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from") is not None
        and not outer_scope.pivots
        # No inner projection may contain an aggregate, a nested SELECT or an Explode
        and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        # A joined inner query must not have joins of its own
        and not (isinstance(from_or_join, exp.Join) and inner_select.args.get("joins"))
        # A filtered inner query must not be the target of a FULL/LEFT/RIGHT join
        and not (
            isinstance(from_or_join, exp.Join)
            and inner_select.args.get("where")
            and from_or_join.side in ("FULL", "LEFT", "RIGHT")
        )
        # A filtered inner query in FROM must not be followed by a FULL/RIGHT join.
        # `or []` also covers a "joins" arg that is present but set to None.
        and not (
            isinstance(from_or_join, exp.From)
            and inner_select.args.get("where")
            and any(
                j.side in ("FULL", "RIGHT") for j in outer_scope.expression.args.get("joins") or []
            )
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
        and not _is_recursive()
        # An ordered inner query is not merged into a query whose scope is a union
        and not (inner_select.args.get("order") and outer_scope.is_union)
        and not isinstance(seq_get(inner_select.expressions, 0), exp.QueryTransform)
    )
223
224
def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
    """
    Give a new name to every inner source whose name is also used by an outer source.

    The name `alias` (the name of the inner query itself in the outer query) is not
    treated as a conflict.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    inner_names = set(inner_scope.selected_sources)
    outer_names = set(outer_scope.selected_sources)

    clashing = outer_names.intersection(inner_names)
    clashing.discard(alias)

    all_names = outer_names.union(inner_names)

    for old_name in clashing:
        fresh_name = find_new_name(all_names, old_name)

        node, _ = inner_scope.selected_sources[old_name]
        fresh_identifier = exp.to_identifier(fresh_name)

        if isinstance(node, exp.Table):
            if node.alias:
                node.set("alias", fresh_identifier)
            else:
                node.replace(exp.alias_(node, fresh_identifier))
        elif isinstance(node.parent, exp.Subquery):
            node.parent.set("alias", exp.TableAlias(this=fresh_identifier))

        # Re-point every inner column that referenced the old source name.
        for column in inner_scope.source_columns(old_name):
            column.set("table", exp.to_identifier(fresh_name))

        inner_scope.rename_source(old_name, fresh_name)
253
254
def _merge_from(
    outer_scope: Scope,
    inner_scope: Scope,
    node_to_replace: t.Union[exp.Subquery, exp.Table],
    alias: str,
) -> None:
    """
    Put the FROM source of the inner query where the outer query referenced the inner query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        node_to_replace (exp.Subquery|exp.Table): outer node that references the inner query
        alias (str): name under which the outer scope knows the inner query
    """
    inner_source = inner_scope.expression.args["from"].this
    # Carry over the "joins" arg of the node being replaced.
    inner_source.set("joins", node_to_replace.args.get("joins"))
    node_to_replace.replace(inner_source)

    replaced_name = node_to_replace.alias_or_name
    merged_name = inner_source.alias_or_name

    # Join hints that named the replaced node now name the merged source.
    for join_hint in outer_scope.join_hints:
        for table in join_hint.find_all(exp.Table):
            if table.alias_or_name == replaced_name:
                table.set("this", exp.to_identifier(merged_name))

    outer_scope.remove_source(alias)
    outer_scope.add_source(merged_name, inner_scope.sources[merged_name])
276
277
def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
    """
    Merge JOIN clauses of inner query into outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join): clause of the outer query that referenced the inner query
    """
    inner_joins = list(inner_scope.expression.args.get("joins") or [])
    if not inner_joins:
        return

    for join in inner_joins:
        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])

    # `or []` (instead of a `.get` default) also handles a "joins" arg that is present
    # but None, matching how the inner joins are read above.
    outer_joins = outer_scope.expression.args.get("joins") or []

    # Maintain the join order: inner joins go first when the inner query was in FROM,
    # otherwise right after the join that referenced it.
    if isinstance(from_or_join, exp.From):
        position = 0
    else:
        position = outer_joins.index(from_or_join) + 1
    outer_joins[position:position] = inner_joins

    outer_scope.expression.set("joins", outer_joins)
302
303
def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
    """
    Merge projections of inner query into outer query.

    Every outer column qualified with `alias` is replaced by a copy of the inner
    projection of the same name.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    # Collect all columns that reference the alias of the inner query
    outer_columns = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            outer_columns[column.name].append(column)

    # Replace columns with the projection expression in the inner query
    for projection in inner_scope.expression.expressions:
        projection_name = projection.alias_or_name
        if not projection_name:
            continue
        columns_to_replace = outer_columns.get(projection_name, [])

        expression = projection.unalias()
        must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED)

        for column in columns_to_replace:
            # Build each replacement from a fresh copy, so that parentheses or an alias
            # added for one column never carry over to the next column.
            replacement = expression.copy()

            # Ensures we don't alter the intended operator precedence if there's additional
            # context surrounding the outer expression (i.e. it's not a simple projection).
            if must_wrap_expression and isinstance(column.parent, (exp.Unary, exp.Binary)):
                replacement = exp.paren(replacement, copy=False)

            # make sure we do not accidentally change the name of the column
            if isinstance(column.parent, exp.Select) and column.name != replacement.name:
                replacement = exp.alias_(replacement, column.name, copy=False)

            column.replace(replacement)
340
341
def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
    """
    Move the WHERE condition of the inner query into the outer query.

    The condition is added to the ON clause of `from_or_join` when that is a join and
    every table named in the condition is joined no later than it; otherwise it is
    added to the outer WHERE clause.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    where = inner_scope.expression.args.get("where")
    if not (where and where.this):
        return

    outer_select = outer_scope.expression

    if isinstance(from_or_join, exp.Join):
        # Gather the sources available up to and including the join being merged.
        outer_from = outer_select.args.get("from")
        available = {outer_from.alias_or_name} if outer_from else set()
        target_name = from_or_join.alias_or_name

        for join in outer_select.args["joins"]:
            joined_name = join.alias_or_name
            available.add(joined_name)
            if joined_name == target_name:
                break

        if exp.column_table_names(where.this) <= available:
            from_or_join.on(where.this, copy=False)
            from_or_join.set("on", from_or_join.args.get("on"))
            return

    outer_select.where(where.this, copy=False)
375
376
def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
    """
    Copy the ORDER arg of the inner query onto the outer query when it is eligible.

    Nothing is done if the outer query has a group, distinct, having or order arg,
    selects from a number of sources other than one, or has a projection containing
    an aggregate function.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    outer_select = outer_scope.expression

    if any(outer_select.args.get(arg) for arg in ["group", "distinct", "having", "order"]):
        return
    if len(outer_scope.selected_sources) != 1:
        return
    if any(projection.find(exp.AggFunc) for projection in outer_select.expressions):
        return

    outer_select.set("order", inner_scope.expression.args.get("order"))
395
396
def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
    """
    Merge the hint of the inner query into the outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    inner_hint = inner_scope.expression.args.get("hint")
    if not inner_hint:
        return

    outer_hint = outer_scope.expression.args.get("hint")
    if not outer_hint:
        # The outer query has no hint of its own: adopt the inner one as a whole.
        outer_scope.expression.set("hint", inner_hint)
        return

    for hint_expression in inner_hint.expressions:
        outer_hint.append("expressions", hint_expression)
407
408
def _pop_cte(inner_scope: Scope) -> None:
    """
    Remove CTE from the AST.

    The whole WITH node is removed when this CTE is its only expression.

    Args:
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    cte = inner_scope.expression.parent
    with_ = cte.parent
    is_only_cte = len(with_.expressions) == 1
    node_to_remove = with_ if is_only_cte else cte
    node_to_remove.pop()
def merge_subqueries(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Rewrite sqlglot AST to merge derived tables into the outer query.

    CTEs that are selected from only once are merged as well.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x CROSS JOIN y'

    If `leave_tables_isolated` is True, this will not merge inner queries into outer
    queries if it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, skip any merge that would result in
            multiple table selects in a single query
    Returns:
        sqlglot.Expression: optimized expression
    """
    # CTEs are handled first; derived tables are merged on the result of that pass.
    without_single_use_ctes = merge_ctes(expression, leave_tables_isolated)
    return merge_derived_tables(without_single_use_ctes, leave_tables_isolated)

Rewrite sqlglot AST to merge derived tables into the outer query.

This also merges CTEs if they are selected from only once.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression).sql()
'SELECT x.a FROM x CROSS JOIN y'

If leave_tables_isolated is True, this will not merge inner queries into outer queries if it would result in multiple table selects in a single query:

>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

Arguments:
  • expression (sqlglot.Expression): expression to optimize
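  • leave_tables_isolated (bool): if True, do not merge inner queries into outer queries when that would result in multiple table selects in a single query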
Returns:

sqlglot.Expression: optimized expression

UNMERGABLE_ARGS = {'operation_modifiers', 'into', 'options', 'distribute', 'connect', 'locks', 'group', 'having', 'cluster', 'distinct', 'limit', 'match', 'format', 'qualify', 'with', 'windows', 'pivots', 'sample', 'settings', 'laterals', 'sort', 'prewhere', 'kind', 'offset'}
SAFE_TO_REPLACE_UNWRAPPED = (<class 'sqlglot.expressions.Column'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.Func'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.Paren'>)
def merge_ctes(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
 71def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
 72    scopes = traverse_scope(expression)
 73
 74    # All places where we select from CTEs.
 75    # We key on the CTE scope so we can detect CTES that are selected from multiple times.
 76    cte_selections = defaultdict(list)
 77    for outer_scope in scopes:
 78        for table, inner_scope in outer_scope.selected_sources.values():
 79            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
 80                cte_selections[id(inner_scope)].append(
 81                    (
 82                        outer_scope,
 83                        inner_scope,
 84                        table,
 85                    )
 86                )
 87
 88    singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
 89    for outer_scope, inner_scope, table in singular_cte_selections:
 90        from_or_join = table.find_ancestor(exp.From, exp.Join)
 91        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 92            alias = table.alias_or_name
 93            _rename_inner_sources(outer_scope, inner_scope, alias)
 94            _merge_from(outer_scope, inner_scope, table, alias)
 95            _merge_expressions(outer_scope, inner_scope, alias)
 96            _merge_order(outer_scope, inner_scope)
 97            _merge_joins(outer_scope, inner_scope, from_or_join)
 98            _merge_where(outer_scope, inner_scope, from_or_join)
 99            _merge_hints(outer_scope, inner_scope)
100            _pop_cte(inner_scope)
101            outer_scope.clear_cache()
102    return expression
def merge_derived_tables(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge every mergeable derived table into the query that selects from it.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the same expression, modified in place
    """
    for outer_scope in traverse_scope(expression):
        for subquery in outer_scope.derived_tables:
            alias = subquery.alias_or_name
            inner_scope = outer_scope.sources[alias]
            from_or_join = subquery.find_ancestor(exp.From, exp.Join)

            if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
                continue

            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, subquery, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_order(outer_scope, inner_scope)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_hints(outer_scope, inner_scope)
            outer_scope.clear_cache()

    return expression