Edit on GitHub

sqlglot.optimizer.merge_subqueries

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from collections import defaultdict
  6
  7from sqlglot import expressions as exp
  8from sqlglot.helper import find_new_name, seq_get
  9from sqlglot.optimizer.scope import Scope, traverse_scope
 10
 11if t.TYPE_CHECKING:
 12    from sqlglot._typing import E
 13
 14    FromOrJoin = t.Union[exp.From, exp.Join]
 15
 16
 17def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
 18    """
 19    Rewrite sqlglot AST to merge derived tables into the outer query.
 20
 21    This also merges CTEs if they are selected from only once.
 22
 23    Example:
 24        >>> import sqlglot
 25        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
 26        >>> merge_subqueries(expression).sql()
 27        'SELECT x.a FROM x CROSS JOIN y'
 28
 29    If `leave_tables_isolated` is True, this will not merge inner queries into outer
 30    queries if it would result in multiple table selects in a single query:
 31        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
 32        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
 33        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'
 34
 35    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
 36
 37    Args:
 38        expression (sqlglot.Expression): expression to optimize
 39        leave_tables_isolated (bool):
 40    Returns:
 41        sqlglot.Expression: optimized expression
 42    """
 43    expression = merge_ctes(expression, leave_tables_isolated)
 44    expression = merge_derived_tables(expression, leave_tables_isolated)
 45    return expression
 46
 47
 48# If a derived table has these Select args, it can't be merged
 49UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
 50    "expressions",
 51    "from_",
 52    "joins",
 53    "where",
 54    "order",
 55    "hint",
 56}
 57
 58
# Inner-query projections that are instances of these types can be substituted into
# the outer query without getting wrapped in parentheses, because the precedence
# won't be altered. `_merge_expressions` parenthesizes any other projection type
# when it replaces a column whose parent is a unary or binary expression.
SAFE_TO_REPLACE_UNWRAPPED = (
    exp.Column,
    exp.EQ,
    exp.Func,
    exp.NEQ,
    exp.Paren,
)
 68
 69
 70def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
 71    scopes = traverse_scope(expression)
 72
 73    # All places where we select from CTEs.
 74    # We key on the CTE scope so we can detect CTES that are selected from multiple times.
 75    cte_selections = defaultdict(list)
 76    for outer_scope in scopes:
 77        for table, inner_scope in outer_scope.selected_sources.values():
 78            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
 79                cte_selections[id(inner_scope)].append(
 80                    (
 81                        outer_scope,
 82                        inner_scope,
 83                        table,
 84                    )
 85                )
 86
 87    singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
 88    for outer_scope, inner_scope, table in singular_cte_selections:
 89        from_or_join = table.find_ancestor(exp.From, exp.Join)
 90        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 91            alias = table.alias_or_name
 92            _rename_inner_sources(outer_scope, inner_scope, alias)
 93            _merge_from(outer_scope, inner_scope, table, alias)
 94            _merge_expressions(outer_scope, inner_scope, alias)
 95            _merge_order(outer_scope, inner_scope)
 96            _merge_joins(outer_scope, inner_scope, from_or_join)
 97            _merge_where(outer_scope, inner_scope, from_or_join)
 98            _merge_hints(outer_scope, inner_scope)
 99            _pop_cte(inner_scope)
100            outer_scope.clear_cache()
101    return expression
102
103
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge every mergeable derived table into the query that selects from it.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the same expression object, modified in place
    """
    for outer_scope in traverse_scope(expression):
        for subquery in outer_scope.derived_tables:
            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
            alias = subquery.alias_or_name
            inner_scope = outer_scope.sources[alias]

            if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
                continue

            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, subquery, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_order(outer_scope, inner_scope)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_hints(outer_scope, inner_scope)
            outer_scope.clear_cache()

    return expression
121
122
def _mergeable(
    outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
) -> bool:
    """
    Return True if `inner_select` can be merged into outer query.

    `inner_select` is the unnested expression of `inner_scope`. The result is the
    conjunction of the checks in the final `return`; they short-circuit, so the
    later helper checks only run once the earlier structural checks have passed.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope): scope of the query selecting from the inner query
        inner_scope (sqlglot.optimizer.scope.Scope): scope of the derived table or CTE to merge
        leave_tables_isolated (bool): if True, refuse to merge when the outer scope
            selects from more than one source
        from_or_join (exp.From|exp.Join): clause of the outer query that contains the inner query
    Returns:
        bool: whether the inner query can be merged
    """
    inner_select = inner_scope.expression.unnest()

    def _is_a_window_expression_in_unmergable_operation():
        # True if the outer query references a windowed inner projection from inside a
        # WHERE, GROUP, ORDER, JOIN, HAVING or aggregate function.
        window_aliases = {s.alias_or_name for s in inner_select.selects if s.find(exp.Window)}
        inner_select_name = from_or_join.alias_or_name
        # Outer columns that sit under one of the clauses listed above
        unmergable_window_columns = [
            column
            for column in outer_scope.columns
            if column.find_ancestor(
                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
            )
        ]
        # ...narrowed to those that point at a windowed projection of the inner query
        window_expressions_in_unmergable = [
            column
            for column in unmergable_window_columns
            if column.table == inner_select_name and column.name in window_aliases
        ]
        return any(window_expressions_in_unmergable)

    def _outer_select_joins_on_inner_select_join():
        """
        All columns from the inner select in the ON clause must be from the first FROM table.

        That is, this can be merged:
            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^           ^
        But this can't:
            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^                  ^
        """
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.alias_or_name

        on = from_or_join.args.get("on")
        if not on:
            return False
        # Names of the inner projections that the ON clause refers to
        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
        inner_from = inner_scope.expression.args.get("from_")
        if not inner_from:
            return False
        inner_from_table = inner_from.alias_or_name
        inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects}
        # NOTE(review): the lookup below raises KeyError if the ON clause names a column
        # that is not among the inner projections - confirm callers rule that out.
        return any(
            col.table != inner_from_table
            for selection in selections
            for col in inner_projections[selection].find_all(exp.Column)
        )

    def _is_recursive():
        # Recursive CTEs look like this:
        #     WITH RECURSIVE cte AS (
        #       SELECT * FROM x  <-- inner scope
        #       UNION ALL
        #       SELECT * FROM cte  <-- outer scope
        #     )
        cte = inner_scope.expression.parent
        node = outer_scope.expression.parent

        # Walk up from the outer query; reaching the inner query's parent means the
        # outer query lives inside the very definition it selects from.
        while node:
            if node is cte:
                return True
            node = node.parent
        return False

    return (
        # The outer query must be a SELECT without a star projection
        isinstance(outer_scope.expression, exp.Select)
        and not outer_scope.expression.is_star
        # The inner query must be a SELECT with a FROM and none of the unmergeable args
        and isinstance(inner_select, exp.Select)
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from_") is not None
        and not outer_scope.pivots
        # Inner projections must not contain aggregates, nested selects or explodes
        and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        # An inner query that is itself joined must not bring its own joins
        and not (isinstance(from_or_join, exp.Join) and inner_select.args.get("joins"))
        # An inner WHERE is not merged across a FULL/LEFT/RIGHT join of the inner query...
        and not (
            isinstance(from_or_join, exp.Join)
            and inner_select.args.get("where")
            and from_or_join.side in ("FULL", "LEFT", "RIGHT")
        )
        # ...nor when the inner query is the FROM source and the outer query has a
        # FULL or RIGHT join
        and not (
            isinstance(from_or_join, exp.From)
            and inner_select.args.get("where")
            and any(
                j.side in ("FULL", "RIGHT") for j in outer_scope.expression.args.get("joins", [])
            )
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
        and not _is_recursive()
        # An inner ORDER BY is not merged into a union
        and not (inner_select.args.get("order") and outer_scope.is_union)
        and not isinstance(seq_get(inner_select.expressions, 0), exp.QueryTransform)
    )
223
224
def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
    """
    Give a fresh name to every inner source whose name is also used by an outer source.

    The name `alias` (the inner query's own name in the outer query) is never
    treated as a conflict.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    outer_names = set(outer_scope.selected_sources)
    inner_names = set(inner_scope.selected_sources)

    clashing = outer_names.intersection(inner_names)
    clashing.discard(alias)

    # New names must not collide with any name in either scope.
    used_names = outer_names | inner_names

    for old_name in clashing:
        fresh_name = find_new_name(used_names, old_name)
        source = inner_scope.selected_sources[old_name][0]
        fresh_identifier = exp.to_identifier(fresh_name)

        if isinstance(source, exp.Table):
            if source.alias:
                source.set("alias", fresh_identifier)
            else:
                source.replace(exp.alias_(source, fresh_identifier))
        elif isinstance(source.parent, exp.Subquery):
            source.parent.set("alias", exp.TableAlias(this=fresh_identifier))

        # Re-point every inner column that referenced the old source name.
        for column in inner_scope.source_columns(old_name):
            column.set("table", exp.to_identifier(fresh_name))

        inner_scope.rename_source(old_name, fresh_name)
253
254
def _merge_from(
    outer_scope: Scope,
    inner_scope: Scope,
    node_to_replace: t.Union[exp.Subquery, exp.Table],
    alias: str,
) -> None:
    """
    Replace the inner query's node in the outer query with the inner query's FROM source.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        node_to_replace (exp.Subquery|exp.Table): node in the outer query that stands for the inner query
        alias (str): name under which the inner query is registered in the outer scope
    """
    inner_source = inner_scope.expression.args["from_"].this
    # Carry over whatever joins were attached to the node being replaced.
    inner_source.set("joins", node_to_replace.args.get("joins"))
    node_to_replace.replace(inner_source)

    replaced_name = node_to_replace.alias_or_name
    inner_source_name = inner_source.alias_or_name

    # Join hints that named the replaced node must now name the inner source.
    for join_hint in outer_scope.join_hints:
        for hinted_table in join_hint.find_all(exp.Table):
            if hinted_table.alias_or_name == replaced_name:
                hinted_table.set("this", exp.to_identifier(inner_source_name))

    outer_scope.remove_source(alias)
    outer_scope.add_source(inner_source_name, inner_scope.sources[inner_source_name])
276
277
def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
    """
    Move the inner query's JOIN clauses into the outer query.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join): clause of the outer query that held the inner query
    """
    inner_joins = list(inner_scope.expression.args.get("joins") or [])
    if not inner_joins:
        return

    # Each joined source of the inner query becomes a source of the outer query.
    for join in inner_joins:
        join_name = join.alias_or_name
        outer_scope.add_source(join_name, inner_scope.sources[join_name])

    outer_joins = outer_scope.expression.args.get("joins", [])

    # Maintain the join order: inner joins go right after the clause they came from.
    insert_at = (
        0 if isinstance(from_or_join, exp.From) else outer_joins.index(from_or_join) + 1
    )
    outer_joins[insert_at:insert_at] = inner_joins

    outer_scope.expression.set("joins", outer_joins)
302
303
def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
    """
    Merge projections of inner query into outer query.

    Every outer column that refers to the inner query through `alias` is replaced
    with a copy of the matching inner projection. Each replacement is built
    independently, so parentheses or an alias added for one column never carry over
    to the next column that uses the same projection.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    # Collect all columns that reference the alias of the inner query
    outer_columns = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            outer_columns[column.name].append(column)

    # Replace columns with the projection expression in the inner query
    for projection in inner_scope.expression.expressions:
        projection_name = projection.alias_or_name
        if not projection_name:
            continue

        columns_to_replace = outer_columns.get(projection_name, [])
        if not columns_to_replace:
            continue

        expression = projection.unalias()
        must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED)
        is_number = expression.is_number

        for column in columns_to_replace:
            parent = column.parent

            # Ensures that we don't merge literal numbers in GROUP BY as they have positional context
            # e.g. don't transform `SELECT a FROM (SELECT 6 AS a) GROUP BY a` to `SELECT 6 AS a GROUP BY 6`,
            # as this would attempt to GROUP BY the 6th projection instead of the column `a`
            if is_number and isinstance(parent, exp.Group):
                column.replace(exp.to_identifier(column.name))
                continue

            # Start from a fresh copy for every column. Reusing (and re-wrapping) a single
            # node across iterations would nest parentheses for each additional use and
            # leak an alias into contexts where an alias is not allowed.
            replacement = expression.copy()

            # Ensures we don't alter the intended operator precedence if there's additional
            # context surrounding the outer expression (i.e. it's not a simple projection).
            if must_wrap_expression and isinstance(parent, (exp.Unary, exp.Binary)):
                replacement = exp.paren(replacement, copy=False)

            # make sure we do not accidentally change the name of the column
            if isinstance(parent, exp.Select) and column.name != replacement.name:
                replacement = exp.alias_(replacement, column.name)

            column.replace(replacement)
351
352
def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
    """
    Move the inner query's WHERE predicate into the outer query.

    When the inner query was joined, the predicate is added to that join's ON clause
    if all tables it references are already joined at that point; otherwise it is
    added to the outer WHERE.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    where = inner_scope.expression.args.get("where")
    if not where or not where.this:
        return

    predicate = where.this
    outer_select = outer_scope.expression

    if isinstance(from_or_join, exp.Join):
        # Names of the sources available up to and including the join being merged.
        from_ = outer_select.args.get("from_")
        joined_sources = {from_.alias_or_name} if from_ else set()
        target_name = from_or_join.alias_or_name

        for join in outer_select.args["joins"]:
            join_name = join.alias_or_name
            joined_sources.add(join_name)
            if join_name == target_name:
                break

        if exp.column_table_names(predicate).issubset(joined_sources):
            from_or_join.on(predicate, copy=False)
            from_or_join.set("on", from_or_join.args.get("on"))
            return

    outer_select.where(predicate, copy=False)
386
387
def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
    """
    Carry the inner query's ORDER clause over to the outer query.

    Nothing is done when the outer query has a group, distinct, having or order of
    its own, selects from anything other than exactly one source, or has an
    aggregate function in its projections.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    outer_select = outer_scope.expression

    if any(outer_select.args.get(arg) for arg in ["group", "distinct", "having", "order"]):
        return

    if len(outer_scope.selected_sources) != 1:
        return

    if any(projection.find(exp.AggFunc) for projection in outer_select.expressions):
        return

    outer_select.set("order", inner_scope.expression.args.get("order"))
406
407
def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
    """
    Merge the inner query's hint into the outer query.

    If the outer query already has a hint, the inner hint's expressions are appended
    to it; otherwise the inner hint becomes the outer query's hint.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    inner_hint = inner_scope.expression.args.get("hint")
    if not inner_hint:
        return

    outer_hint = outer_scope.expression.args.get("hint")
    if not outer_hint:
        outer_scope.expression.set("hint", inner_hint)
        return

    for hint_expression in inner_hint.expressions:
        outer_hint.append("expressions", hint_expression)
418
419
def _pop_cte(inner_scope: Scope) -> None:
    """
    Detach the CTE that defines `inner_scope` from the AST.

    When it is the only expression of its parent, the parent is removed as a whole.

    Args:
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    cte = inner_scope.expression.parent
    with_ = cte.parent
    is_only_cte = len(with_.expressions) == 1
    node_to_remove = with_ if is_only_cte else cte
    node_to_remove.pop()
def merge_subqueries(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge derived tables, and CTEs that are selected from only once, into the outer query.

    CTEs are merged first; derived tables are merged on the result of that pass.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x CROSS JOIN y'

    With `leave_tables_isolated` set to True, an inner query is left alone whenever merging
    it would result in multiple table selects in a single query:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, skip merges that would put several
            table selects into one query
    Returns:
        sqlglot.Expression: optimized expression
    """
    return merge_derived_tables(
        merge_ctes(expression, leave_tables_isolated),
        leave_tables_isolated,
    )

Rewrite sqlglot AST to merge derived tables into the outer query.

This also merges CTEs if they are selected from only once.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression).sql()
'SELECT x.a FROM x CROSS JOIN y'

If leave_tables_isolated is True, this will not merge inner queries into outer queries if it would result in multiple table selects in a single query:

>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

Arguments:
  • expression (sqlglot.Expression): expression to optimize
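  • leave_tables_isolated (bool): if True, do not merge an inner query when that would result in multiple table selects in a single query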
Returns:

sqlglot.Expression: optimized expression

UNMERGABLE_ARGS = {'settings', 'options', 'with_', 'distinct', 'into', 'having', 'qualify', 'laterals', 'group', 'limit', 'cluster', 'pivots', 'distribute', 'offset', 'connect', 'prewhere', 'operation_modifiers', 'match', 'format', 'kind', 'locks', 'sample', 'sort', 'windows'}
SAFE_TO_REPLACE_UNWRAPPED = (<class 'sqlglot.expressions.Column'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.Func'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.Paren'>)
def merge_ctes(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
 71def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
 72    scopes = traverse_scope(expression)
 73
 74    # All places where we select from CTEs.
 75    # We key on the CTE scope so we can detect CTES that are selected from multiple times.
 76    cte_selections = defaultdict(list)
 77    for outer_scope in scopes:
 78        for table, inner_scope in outer_scope.selected_sources.values():
 79            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
 80                cte_selections[id(inner_scope)].append(
 81                    (
 82                        outer_scope,
 83                        inner_scope,
 84                        table,
 85                    )
 86                )
 87
 88    singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
 89    for outer_scope, inner_scope, table in singular_cte_selections:
 90        from_or_join = table.find_ancestor(exp.From, exp.Join)
 91        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
 92            alias = table.alias_or_name
 93            _rename_inner_sources(outer_scope, inner_scope, alias)
 94            _merge_from(outer_scope, inner_scope, table, alias)
 95            _merge_expressions(outer_scope, inner_scope, alias)
 96            _merge_order(outer_scope, inner_scope)
 97            _merge_joins(outer_scope, inner_scope, from_or_join)
 98            _merge_where(outer_scope, inner_scope, from_or_join)
 99            _merge_hints(outer_scope, inner_scope)
100            _pop_cte(inner_scope)
101            outer_scope.clear_cache()
102    return expression
def merge_derived_tables(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge every mergeable derived table into the query that selects from it.

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): forwarded to `_mergeable`
    Returns:
        sqlglot.Expression: the same expression object, modified in place
    """
    for outer_scope in traverse_scope(expression):
        for subquery in outer_scope.derived_tables:
            from_or_join = subquery.find_ancestor(exp.From, exp.Join)
            alias = subquery.alias_or_name
            inner_scope = outer_scope.sources[alias]

            if not _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
                continue

            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, subquery, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_order(outer_scope, inner_scope)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_hints(outer_scope, inner_scope)
            outer_scope.clear_cache()

    return expression