Edit on GitHub

sqlglot.optimizer.merge_subqueries

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from collections import defaultdict
  6
  7from sqlglot import expressions as exp
  8from sqlglot.helper import find_new_name
  9from sqlglot.optimizer.scope import Scope, traverse_scope
 10
 11if t.TYPE_CHECKING:
 12    from sqlglot._typing import E
 13
 14    FromOrJoin = t.Union[exp.From, exp.Join]
 15
 16
 17def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
 18    """
 19    Rewrite sqlglot AST to merge derived tables into the outer query.
 20
 21    This also merges CTEs if they are selected from only once.
 22
 23    Example:
 24        >>> import sqlglot
 25        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
 26        >>> merge_subqueries(expression).sql()
 27        'SELECT x.a FROM x CROSS JOIN y'
 28
 29    If `leave_tables_isolated` is True, this will not merge inner queries into outer
 30    queries if it would result in multiple table selects in a single query:
 31        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
 32        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
 33        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'
 34
 35    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
 36
 37    Args:
 38        expression (sqlglot.Expression): expression to optimize
 39        leave_tables_isolated (bool):
 40    Returns:
 41        sqlglot.Expression: optimized expression
 42    """
 43    expression = merge_ctes(expression, leave_tables_isolated)
 44    expression = merge_derived_tables(expression, leave_tables_isolated)
 45    return expression
 46
 47
 48# If a derived table has these Select args, it can't be merged
 49UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
 50    "expressions",
 51    "from",
 52    "joins",
 53    "where",
 54    "order",
 55    "hint",
 56}
 57
 58
# Inner-query projections that are instances of these types can be substituted for
# the outer column that references them without being wrapped in parentheses,
# because the precedence won't be altered. Projections of any other type are
# parenthesized when the column they replace sits under a unary or binary operator.
SAFE_TO_REPLACE_UNWRAPPED = (
    exp.Column,
    exp.EQ,
    exp.Func,
    exp.NEQ,
    exp.Paren,
)
 68
 69
 70def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
 71    scopes = traverse_scope(expression)
 72
 73    # All places where we select from CTEs.
 74    # We key on the CTE scope so we can detect CTES that are selected from multiple times.
 75    cte_selections = defaultdict(list)
 76    for outer_scope in scopes:
 77        for table, inner_scope in outer_scope.selected_sources.values():
 78            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
 79                cte_selections[id(inner_scope)].append(
 80                    (
 81                        outer_scope,
 82                        inner_scope,
 83                        table,
 84                    )
 85                )
 86
 87    singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
 88    for outer_scope, inner_scope, table in singular_cte_selections:
 89        from_or_join = table.find_ancestor(exp.From, exp.Join)
 90        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 91            alias = table.alias_or_name
 92            _rename_inner_sources(outer_scope, inner_scope, alias)
 93            _merge_from(outer_scope, inner_scope, table, alias)
 94            _merge_expressions(outer_scope, inner_scope, alias)
 95            _merge_joins(outer_scope, inner_scope, from_or_join)
 96            _merge_where(outer_scope, inner_scope, from_or_join)
 97            _merge_order(outer_scope, inner_scope)
 98            _merge_hints(outer_scope, inner_scope)
 99            _pop_cte(inner_scope)
100            outer_scope.clear_cache()
101    return expression
102
103
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge every mergeable derived table into the query that selects from it.

    Args:
        expression (sqlglot.Expression): expression to optimize (modified in place)
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the same `expression` object
    """
    for outer in traverse_scope(expression):
        for derived in outer.derived_tables:
            clause = derived.find_ancestor(exp.From, exp.Join)
            name = derived.alias_or_name
            inner = outer.sources[name]
            if not _mergeable(outer, inner, leave_tables_isolated, clause):
                continue

            _rename_inner_sources(outer, inner, name)
            _merge_from(outer, inner, derived, name)
            _merge_expressions(outer, inner, name)
            _merge_joins(outer, inner, clause)
            _merge_where(outer, inner, clause)
            _merge_order(outer, inner)
            _merge_hints(outer, inner)
            outer.clear_cache()

    return expression
121
122
def _mergeable(
    outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
) -> bool:
    """
    Return True if the query of `inner_scope` can be merged into the query of `outer_scope`.

    Args:
        outer_scope: scope of the query that selects from the inner query
        inner_scope: scope of the candidate derived table or CTE
        leave_tables_isolated: if True, refuse to merge when the outer scope selects
            from more than one source
        from_or_join: the FROM or JOIN clause of the outer query through which the
            inner query is referenced

    Returns:
        True only if every condition of the final boolean expression holds.
    """
    # NOTE(review): `unnest()` presumably strips wrapping parentheses/subquery nodes - confirm.
    inner_select = inner_scope.expression.unnest()

    def _is_a_window_expression_in_unmergable_operation():
        # True if the outer query references, from inside a WHERE, GROUP BY, ORDER BY,
        # JOIN, HAVING or aggregate function, an inner projection that contains a window.
        window_expressions = inner_select.find_all(exp.Window)
        # Names under which the parents of the window expressions are exposed.
        window_alias_names = {window.parent.alias_or_name for window in window_expressions}
        inner_select_name = from_or_join.alias_or_name
        unmergable_window_columns = [
            column
            for column in outer_scope.columns
            if column.find_ancestor(
                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
            )
        ]
        window_expressions_in_unmergable = [
            column
            for column in unmergable_window_columns
            if column.table == inner_select_name and column.name in window_alias_names
        ]
        return any(window_expressions_in_unmergable)

    def _outer_select_joins_on_inner_select_join():
        """
        All columns from the inner select in the ON clause must be from the first FROM table.

        That is, this can be merged:
            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^           ^
        But this can't:
            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^                  ^
        """
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.alias_or_name

        on = from_or_join.args.get("on")
        if not on:
            return False
        # Names of the inner projections that the ON clause refers to.
        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
        inner_from = inner_scope.expression.args.get("from")
        if not inner_from:
            return False
        inner_from_table = inner_from.alias_or_name
        inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects}
        # NOTE(review): raises KeyError if the ON clause names a column that is not
        # among the inner projections - confirm callers always qualify/expand first.
        return any(
            col.table != inner_from_table
            for selection in selections
            for col in inner_projections[selection].find_all(exp.Column)
        )

    def _is_recursive():
        # Recursive CTEs look like this:
        #     WITH RECURSIVE cte AS (
        #       SELECT * FROM x  <-- inner scope
        #       UNION ALL
        #       SELECT * FROM cte  <-- outer scope
        #     )
        cte = inner_scope.expression.parent
        node = outer_scope.expression.parent

        # Walk up from the outer query; reaching the inner query's parent means the
        # outer query is nested inside the very CTE it selects from.
        while node:
            if node is cte:
                return True
            node = node.parent
        return False

    return (
        # Both sides must be SELECTs, and the outer one must not project a star.
        isinstance(outer_scope.expression, exp.Select)
        and not outer_scope.expression.is_star
        and isinstance(inner_select, exp.Select)
        # The inner SELECT must not have any of the args in UNMERGABLE_ARGS set.
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from") is not None
        and not outer_scope.pivots
        # Inner projections must not contain aggregates, nested SELECTs or Explode.
        and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        # An inner query with a WHERE is not merged when it is joined with a
        # FULL, LEFT or RIGHT join ...
        and not (
            isinstance(from_or_join, exp.Join)
            and inner_select.args.get("where")
            and from_or_join.side in ("FULL", "LEFT", "RIGHT")
        )
        # ... nor when it is the FROM source and the outer query has a FULL or RIGHT join.
        and not (
            isinstance(from_or_join, exp.From)
            and inner_select.args.get("where")
            and any(
                j.side in ("FULL", "RIGHT") for j in outer_scope.expression.args.get("joins", [])
            )
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
        and not _is_recursive()
        # An inner ORDER BY is not merged into an outer scope that is a union.
        and not (inner_select.args.get("order") and outer_scope.is_union)
    )
222
223
def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
    """
    Give inner-query sources a fresh name when their name clashes with an outer-query source.

    The name `alias` is never treated as a clash.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    inner_names = set(inner_scope.selected_sources)
    outer_names = set(outer_scope.selected_sources)
    used_names = outer_names.union(inner_names)

    clashes = outer_names.intersection(inner_names)
    clashes.discard(alias)

    for old_name in clashes:
        fresh_name = find_new_name(used_names, old_name)
        fresh_identifier = exp.to_identifier(fresh_name)
        node, _ = inner_scope.selected_sources[old_name]

        if isinstance(node, exp.Subquery):
            node.set("alias", exp.TableAlias(this=fresh_identifier))
        elif isinstance(node, exp.Table):
            if node.alias:
                node.set("alias", fresh_identifier)
            else:
                # An unaliased table gets an aliased version of itself swapped in.
                node.replace(exp.alias_(node, fresh_identifier))

        # Re-point the inner columns that were qualified with the old name.
        for column in inner_scope.source_columns(old_name):
            column.set("table", exp.to_identifier(fresh_name))

        inner_scope.rename_source(old_name, fresh_name)
252
253
def _merge_from(
    outer_scope: Scope,
    inner_scope: Scope,
    node_to_replace: t.Union[exp.Subquery, exp.Table],
    alias: str,
) -> None:
    """
    Replace the outer reference to the inner query with the inner query's FROM source.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        node_to_replace (exp.Subquery|exp.Table): outer node that refers to the inner query
        alias (str): name under which the outer scope knows the inner query
    """
    inner_source = inner_scope.expression.args["from"].this
    # Joins attached to the replaced node are carried over to its replacement.
    inner_source.set("joins", node_to_replace.args.get("joins"))
    node_to_replace.replace(inner_source)

    replaced_name = node_to_replace.alias_or_name
    source_name = inner_source.alias_or_name

    # Join hints naming the replaced node must now name the source that took its place.
    for hint in outer_scope.join_hints:
        for hinted_table in hint.find_all(exp.Table):
            if hinted_table.alias_or_name == replaced_name:
                hinted_table.set("this", exp.to_identifier(source_name))

    outer_scope.remove_source(alias)
    outer_scope.add_source(source_name, inner_scope.sources[source_name])
275
276
def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
    """
    Move the JOIN clauses of the inner query into the outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join): outer clause that referenced the inner query
    """
    inner_joins = list(inner_scope.expression.args.get("joins") or [])

    # Each joined source becomes a source of the outer scope as well.
    for inner_join in inner_joins:
        outer_scope.add_source(
            inner_join.alias_or_name, inner_scope.sources[inner_join.alias_or_name]
        )

    if not inner_joins:
        return

    outer_joins = outer_scope.expression.args.get("joins", [])

    # Maintain the join order: inner joins go first when the inner query was the FROM
    # source, otherwise directly after the join that referenced it.
    if isinstance(from_or_join, exp.From):
        insert_at = 0
    else:
        insert_at = outer_joins.index(from_or_join) + 1

    outer_joins[insert_at:insert_at] = inner_joins
    outer_scope.expression.set("joins", outer_joins)
300
301
def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
    """
    Merge projections of inner query into outer query.

    Every outer column qualified with `alias` is replaced by a copy of the inner
    projection of the same name.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str): name under which the outer query refers to the inner query
    """
    # Collect all columns that reference the alias of the inner query
    outer_columns = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            outer_columns[column.name].append(column)

    # Replace columns with the projection expression in the inner query
    for projection in inner_scope.expression.expressions:
        projection_name = projection.alias_or_name
        if not projection_name:
            continue

        columns_to_replace = outer_columns.get(projection_name)
        if not columns_to_replace:
            continue

        inner_expression = projection.unalias()
        must_wrap_expression = not isinstance(inner_expression, SAFE_TO_REPLACE_UNWRAPPED)

        for column in columns_to_replace:
            # Build the replacement per column. Rebinding a shared variable to the
            # wrapped form would leak the parentheses into every later column (and
            # nest them on repeated wrapping), and wrapping the inner node itself
            # would re-parent a node that still belongs to the inner query.
            replacement = inner_expression.copy()

            # Ensures we don't alter the intended operator precedence if there's additional
            # context surrounding the outer expression (i.e. it's not a simple projection).
            if must_wrap_expression and isinstance(column.parent, (exp.Unary, exp.Binary)):
                replacement = exp.paren(replacement, copy=False)

            column.replace(replacement)
334
335
def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
    """
    Move the WHERE predicate of the inner query into the outer query.

    When the inner query was joined, the predicate is added to that join's ON clause
    if all tables it mentions are joined by that point; otherwise it is added to the
    outer WHERE.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    inner_where = inner_scope.expression.args.get("where")
    if not (inner_where and inner_where.this):
        return

    predicate = inner_where.this
    outer_select = outer_scope.expression

    if isinstance(from_or_join, exp.Join):
        # Names of the sources available up to and including the join in question.
        from_ = outer_select.args.get("from")
        available = {from_.alias_or_name} if from_ else set()
        target_name = from_or_join.alias_or_name

        for join in outer_select.args["joins"]:
            joined_name = join.alias_or_name
            available.add(joined_name)
            if joined_name == target_name:
                break

        if exp.column_table_names(predicate) <= available:
            from_or_join.on(predicate, copy=False)
            # NOTE(review): re-setting "on" to its own value presumably refreshes
            # the parent links of the combined condition - confirm.
            from_or_join.set("on", from_or_join.args.get("on"))
            return

    outer_select.where(predicate, copy=False)
369
370
def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
    """
    Carry the ORDER clause of the inner query over to the outer query.

    Nothing happens when the outer query has a GROUP BY, DISTINCT, HAVING or ORDER BY
    of its own, selects from a number of sources other than one, or has an aggregate
    function among its projections.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    outer_select = outer_scope.expression

    for blocking_arg in ("group", "distinct", "having", "order"):
        if outer_select.args.get(blocking_arg):
            return

    if len(outer_scope.selected_sources) != 1:
        return

    for projection in outer_select.expressions:
        if projection.find(exp.AggFunc):
            return

    outer_select.set("order", inner_scope.expression.args.get("order"))
389
390
def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
    """
    Carry the hints of the inner query over to the outer query.

    The inner hint expressions are appended to an existing outer hint; if the outer
    query has no hint, the inner hint node is set on it directly.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    inner_hint = inner_scope.expression.args.get("hint")
    if not inner_hint:
        return

    outer_hint = outer_scope.expression.args.get("hint")
    if not outer_hint:
        outer_scope.expression.set("hint", inner_hint)
        return

    for hint_expression in inner_hint.expressions:
        outer_hint.append("expressions", hint_expression)
401
402
def _pop_cte(inner_scope: Scope) -> None:
    """
    Detach the CTE of `inner_scope` from the AST.

    Args:
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    cte = inner_scope.expression.parent
    with_ = cte.parent

    # When this is the only CTE, the whole WITH node is removed rather than just the CTE.
    node_to_remove = with_ if len(with_.expressions) == 1 else cte
    node_to_remove.pop()
def merge_subqueries(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Fold mergeable derived tables and single-use CTEs into the query that selects from them.

    The CTE pass (`merge_ctes`) runs first, followed by the derived-table pass
    (`merge_derived_tables`); both receive the same `leave_tables_isolated` flag.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x CROSS JOIN y'

    If `leave_tables_isolated` is True, an inner query is left untouched whenever merging
    it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, skip merges that would result in
            multiple table selects in a single query
    Returns:
        sqlglot.Expression: optimized expression
    """
    after_cte_pass = merge_ctes(expression, leave_tables_isolated)
    return merge_derived_tables(after_cte_pass, leave_tables_isolated)

Rewrite sqlglot AST to merge derived tables into the outer query.

This also merges CTEs if they are selected from only once.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression).sql()
'SELECT x.a FROM x CROSS JOIN y'

If leave_tables_isolated is True, this will not merge inner queries into outer queries if it would result in multiple table selects in a single query:

>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

Arguments:
  • expression (sqlglot.Expression): expression to optimize
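  • leave_tables_isolated (bool): if True, do not merge an inner query into an outer query when that would result in multiple table selects in a single query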
Returns:

sqlglot.Expression: optimized expression

UNMERGABLE_ARGS = {'pivots', 'cluster', 'distribute', 'windows', 'locks', 'distinct', 'connect', 'operation_modifiers', 'settings', 'with', 'having', 'qualify', 'limit', 'laterals', 'match', 'options', 'format', 'group', 'sort', 'kind', 'sample', 'offset', 'into', 'prewhere'}
SAFE_TO_REPLACE_UNWRAPPED = (<class 'sqlglot.expressions.Column'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.Func'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.Paren'>)
def merge_ctes(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
 71def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
 72    scopes = traverse_scope(expression)
 73
 74    # All places where we select from CTEs.
 75    # We key on the CTE scope so we can detect CTES that are selected from multiple times.
 76    cte_selections = defaultdict(list)
 77    for outer_scope in scopes:
 78        for table, inner_scope in outer_scope.selected_sources.values():
 79            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
 80                cte_selections[id(inner_scope)].append(
 81                    (
 82                        outer_scope,
 83                        inner_scope,
 84                        table,
 85                    )
 86                )
 87
 88    singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
 89    for outer_scope, inner_scope, table in singular_cte_selections:
 90        from_or_join = table.find_ancestor(exp.From, exp.Join)
 91        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 92            alias = table.alias_or_name
 93            _rename_inner_sources(outer_scope, inner_scope, alias)
 94            _merge_from(outer_scope, inner_scope, table, alias)
 95            _merge_expressions(outer_scope, inner_scope, alias)
 96            _merge_joins(outer_scope, inner_scope, from_or_join)
 97            _merge_where(outer_scope, inner_scope, from_or_join)
 98            _merge_order(outer_scope, inner_scope)
 99            _merge_hints(outer_scope, inner_scope)
100            _pop_cte(inner_scope)
101            outer_scope.clear_cache()
102    return expression
def merge_derived_tables(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge every mergeable derived table into the query that selects from it.

    Args:
        expression (sqlglot.Expression): expression to optimize (modified in place)
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the same `expression` object
    """
    for outer in traverse_scope(expression):
        for derived in outer.derived_tables:
            clause = derived.find_ancestor(exp.From, exp.Join)
            name = derived.alias_or_name
            inner = outer.sources[name]
            if not _mergeable(outer, inner, leave_tables_isolated, clause):
                continue

            _rename_inner_sources(outer, inner, name)
            _merge_from(outer, inner, derived, name)
            _merge_expressions(outer, inner, name)
            _merge_joins(outer, inner, clause)
            _merge_where(outer, inner, clause)
            _merge_order(outer, inner)
            _merge_hints(outer, inner)
            outer.clear_cache()

    return expression