sqlglot.optimizer.merge_subqueries
from __future__ import annotations

import typing as t

from collections import defaultdict

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope

if t.TYPE_CHECKING:
    from sqlglot._typing import E

    FromOrJoin = t.Union[exp.From, exp.Join]


def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Inline derived tables, and CTEs that are selected from only once, into their outer queries.

    CTEs are merged first (`merge_ctes`), then derived tables (`merge_derived_tables`).

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x CROSS JOIN y'

    With `leave_tables_isolated=True`, an inner query is not merged into an outer query that
    already selects from more than one source, so tables stay in separate queries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, skip any merge into an outer query that
            selects from more than one source.
    Returns:
        sqlglot.Expression: optimized expression
    """
    ctes_merged = merge_ctes(expression, leave_tables_isolated)
    return merge_derived_tables(ctes_merged, leave_tables_isolated)


# If a derived table has these Select args, it can't be merged
UNMERGABLE_ARGS = {
    arg
    for arg in exp.Select.arg_types
    if arg not in {"expressions", "from", "joins", "where", "order", "hint"}
}


# Projections in the outer query that are instances of these types can be replaced
# without getting wrapped in parentheses, because the precedence won't be altered.
SAFE_TO_REPLACE_UNWRAPPED = (
    exp.Column,
    exp.EQ,
    exp.Func,
    exp.NEQ,
    exp.Paren,
)


def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge CTEs that are selected from exactly once into the query that selects from them.

    Args:
        expression: expression to optimize; it is modified in place.
        leave_tables_isolated: forwarded to `_mergeable`; when True, nothing is merged into
            an outer query that selects from more than one source.

    Returns:
        The same `expression` object, with merged CTEs removed from their WITH clause.
    """
    scopes = traverse_scope(expression)

    # All places where we select from CTEs.
    # We key on the CTE scope so we can detect CTEs that are selected from multiple times.
    cte_selections = defaultdict(list)
    for outer_scope in scopes:
        for table, inner_scope in outer_scope.selected_sources.values():
            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
                cte_selections[id(inner_scope)].append((outer_scope, inner_scope, table))

    # Only CTEs referenced exactly once can be inlined without duplicating their body.
    singular_cte_selections = [
        selections[0] for selections in cte_selections.values() if len(selections) == 1
    ]
    for outer_scope, inner_scope, table in singular_cte_selections:
        from_or_join = table.find_ancestor(exp.From, exp.Join)
        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
            alias = table.alias_or_name
            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, table, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_order(outer_scope, inner_scope)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_hints(outer_scope, inner_scope)
            _pop_cte(inner_scope)
            outer_scope.clear_cache()
    return expression


def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge derived tables (subqueries used as sources) into their enclosing queries.

    Every scope of `expression` is visited; each derived table that `_mergeable` accepts is
    inlined into the outer query, after which the outer scope's caches are cleared.

    Args:
        expression: expression to optimize; it is modified in place.
        leave_tables_isolated: forwarded to `_mergeable`; when True, nothing is merged into
            an outer query that selects from more than one source.

    Returns:
        The same `expression` object.
    """
    for outer_scope in traverse_scope(expression):
        for derived_table in outer_scope.derived_tables:
            clause = derived_table.find_ancestor(exp.From, exp.Join)
            name = derived_table.alias_or_name
            source_scope = outer_scope.sources[name]
            if not _mergeable(outer_scope, source_scope, leave_tables_isolated, clause):
                continue

            _rename_inner_sources(outer_scope, source_scope, name)
            _merge_from(outer_scope, source_scope, derived_table, name)
            _merge_expressions(outer_scope, source_scope, name)
            _merge_order(outer_scope, source_scope)
            _merge_joins(outer_scope, source_scope, clause)
            _merge_where(outer_scope, source_scope, clause)
            _merge_hints(outer_scope, source_scope)
            outer_scope.clear_cache()

    return expression


def _mergeable(
    outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
) -> bool:
    """
    Return True if the query of `inner_scope` can be merged into the query of `outer_scope`.

    Args:
        outer_scope: scope of the query that selects from the inner query.
        inner_scope: scope of the derived table / CTE being considered.
        leave_tables_isolated: when True, refuse if the outer query has more than one source.
        from_or_join: the FROM or JOIN node of the outer query that references the inner query.
    """
    inner_select = inner_scope.expression.unnest()

    def _is_a_window_expression_in_unmergable_operation():
        # Inner projections that contain a window function.
        window_aliases = {s.alias_or_name for s in inner_select.selects if s.find(exp.Window)}
        inner_select_name = from_or_join.alias_or_name
        # Outer columns located under WHERE / GROUP BY / ORDER BY / JOIN / HAVING or inside an
        # aggregate: inlining a window expression at those positions is not allowed.
        unmergable_window_columns = [
            column
            for column in outer_scope.columns
            if column.find_ancestor(
                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
            )
        ]
        return any(
            column.table == inner_select_name and column.name in window_aliases
            for column in unmergable_window_columns
        )

    def _outer_select_joins_on_inner_select_join():
        """
        All columns from the inner select in the ON clause must be from the first FROM table.

        That is, this can be merged:
            SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^           ^
        But this can't:
            SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
                                         ^^^                  ^
        """
        if not isinstance(from_or_join, exp.Join):
            return False

        alias = from_or_join.alias_or_name

        on = from_or_join.args.get("on")
        if not on:
            return False
        selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
        inner_from = inner_scope.expression.args.get("from")
        if not inner_from:
            return False
        inner_from_table = inner_from.alias_or_name
        inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects}
        for selection in selections:
            projection = inner_projections.get(selection)
            if projection is None:
                # The ON clause references a column the inner query does not project by name
                # (e.g. `SELECT *`); its origin can't be checked, so refuse to merge instead of
                # raising KeyError.
                return True
            if any(col.table != inner_from_table for col in projection.find_all(exp.Column)):
                return True
        return False

    def _is_recursive():
        # Recursive CTEs look like this:
        #     WITH RECURSIVE cte AS (
        #       SELECT * FROM x  <-- inner scope
        #       UNION ALL
        #       SELECT * FROM cte  <-- outer scope
        #     )
        cte = inner_scope.expression.parent
        node = outer_scope.expression.parent

        while node:
            if node is cte:
                return True
            node = node.parent
        return False

    # `joins` may be absent or explicitly None on the outer select.
    outer_joins = outer_scope.expression.args.get("joins") or []

    return (
        isinstance(outer_scope.expression, exp.Select)
        and not outer_scope.expression.is_star
        and isinstance(inner_select, exp.Select)
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from") is not None
        and not outer_scope.pivots
        and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        and not (isinstance(from_or_join, exp.Join) and inner_select.args.get("joins"))
        and not (
            isinstance(from_or_join, exp.Join)
            and inner_select.args.get("where")
            and from_or_join.side in ("FULL", "LEFT", "RIGHT")
        )
        and not (
            isinstance(from_or_join, exp.From)
            and inner_select.args.get("where")
            and any(j.side in ("FULL", "RIGHT") for j in outer_joins)
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
        and not _is_recursive()
        and not (inner_select.args.get("order") and outer_scope.is_union)
        and not isinstance(seq_get(inner_select.expressions, 0), exp.QueryTransform)
    )


def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
    """
    Renames any sources in the inner query that conflict with names in the outer query.

    `alias` (the name under which the outer query sees the inner query) is not treated as a
    conflict, since that source is the one being replaced.
    """
    inner_taken = set(inner_scope.selected_sources)
    outer_taken = set(outer_scope.selected_sources)
    conflicts = outer_taken.intersection(inner_taken)
    conflicts -= {alias}

    taken = outer_taken.union(inner_taken)

    for conflict in conflicts:
        new_name = find_new_name(taken, conflict)
        # Reserve the new name so a later conflict can never be given the same one.
        taken.add(new_name)

        source, _ = inner_scope.selected_sources[conflict]
        new_alias = exp.to_identifier(new_name)

        if isinstance(source, exp.Table) and source.alias:
            source.set("alias", new_alias)
        elif isinstance(source, exp.Table):
            source.replace(exp.alias_(source, new_alias))
        elif isinstance(source.parent, exp.Subquery):
            source.parent.set("alias", exp.TableAlias(this=new_alias))

        # Re-point the inner query's column references at the new name.
        for column in inner_scope.source_columns(conflict):
            column.set("table", exp.to_identifier(new_name))

        inner_scope.rename_source(conflict, new_name)


def _merge_from(
    outer_scope: Scope,
    inner_scope: Scope,
    node_to_replace: t.Union[exp.Subquery, exp.Table],
    alias: str,
) -> None:
    """
    Merge FROM clause of inner query into outer query.

    `node_to_replace` is swapped for the inner query's FROM source, join hints naming the old
    source are re-pointed to the new one, and the outer scope's source mapping is updated.
    """
    new_subquery = inner_scope.expression.args["from"].this
    # Carry over any joins attached directly to the node being replaced.
    new_subquery.set("joins", node_to_replace.args.get("joins"))
    node_to_replace.replace(new_subquery)
    for join_hint in outer_scope.join_hints:
        tables = join_hint.find_all(exp.Table)
        for table in tables:
            if table.alias_or_name == node_to_replace.alias_or_name:
                table.set("this", exp.to_identifier(new_subquery.alias_or_name))
    outer_scope.remove_source(alias)
    outer_scope.add_source(
        new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]
    )


def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
    """
    Merge JOIN clauses of inner query into outer query.

    The inner joins are inserted right after `from_or_join`. If the inner query sat in the
    outer FROM clause, they are inserted at the front of the outer joins instead. Either way,
    join order is preserved.
    """
    new_joins = list(inner_scope.expression.args.get("joins") or [])
    for join in new_joins:
        outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name])

    if new_joins:
        # `joins` may be absent or explicitly None on the outer select.
        outer_joins = outer_scope.expression.args.get("joins") or []

        # Maintain the join order
        if isinstance(from_or_join, exp.From):
            position = 0
        else:
            position = outer_joins.index(from_or_join) + 1
        outer_joins[position:position] = new_joins

        outer_scope.expression.set("joins", outer_joins)


def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
    """
    Merge projections of inner query into outer query.

    Every outer column `alias.<name>` is replaced with the inner projection named `<name>`.
    When the column is an operand of a unary or binary operator, the replacement is
    parenthesized (unless it is precedence-safe). When the column is a bare projection, the
    replacement is aliased back to the column name.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        alias (str)
    """
    # Collect all columns that reference the alias of the inner query
    outer_columns = defaultdict(list)
    for column in outer_scope.columns:
        if column.table == alias:
            outer_columns[column.name].append(column)

    # Replace columns with the projection expression in the inner query
    for expression in inner_scope.expression.expressions:
        projection_name = expression.alias_or_name
        if not projection_name:
            continue
        columns_to_replace = outer_columns.get(projection_name, [])

        expression = expression.unalias()
        must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED)

        for column in columns_to_replace:
            # Build every replacement from a fresh copy. Reusing one node across iterations
            # would leak a paren/alias added for one column into the next replacement,
            # e.g. `(x.b + x.c AS a) + 1`.
            replacement = expression.copy()

            # Ensures we don't alter the intended operator precedence if there's additional
            # context surrounding the outer expression (i.e. it's not a simple projection).
            if isinstance(column.parent, (exp.Unary, exp.Binary)) and must_wrap_expression:
                replacement = exp.paren(replacement, copy=False)

            # make sure we do not accidentally change the name of the column
            if isinstance(column.parent, exp.Select) and column.name != replacement.name:
                replacement = exp.alias_(replacement, column.name)

            column.replace(replacement)


def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
    """
    Merge WHERE clause of inner query into outer query.

    If the inner query was joined, its predicate is added to that join's ON clause. This
    happens only when the predicate references nothing but the outer FROM source and the
    joins up to and including this one. Otherwise the predicate is added to the outer WHERE.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
        from_or_join (exp.From|exp.Join)
    """
    where = inner_scope.expression.args.get("where")
    if not where or not where.this:
        return

    expression = outer_scope.expression

    if isinstance(from_or_join, exp.Join):
        # Merge predicates from an outer join to the ON clause
        # if it only has columns that are already joined
        from_ = expression.args.get("from")
        sources = {from_.alias_or_name} if from_ else set()

        for join in expression.args["joins"]:
            source = join.alias_or_name
            sources.add(source)
            if source == from_or_join.alias_or_name:
                break

        if exp.column_table_names(where.this) <= sources:
            from_or_join.on(where.this, copy=False)
            from_or_join.set("on", from_or_join.args.get("on"))
            return

    expression.where(where.this, copy=False)


def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
    """
    Merge ORDER clause of inner query into outer query.

    The inner ORDER is copied over only when the outer query has no GROUP BY / DISTINCT /
    HAVING / ORDER, selects from exactly one source, and has no aggregate in its projections.

    Args:
        outer_scope (sqlglot.optimizer.scope.Scope)
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    if (
        any(
            outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]
        )
        or len(outer_scope.selected_sources) != 1
        or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
    ):
        return

    outer_scope.expression.set("order", inner_scope.expression.args.get("order"))


def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
    """
    Move the inner query's hint onto the outer query.

    If the outer query already has a hint, the inner hint's expressions are appended to it.
    """
    inner_scope_hint = inner_scope.expression.args.get("hint")
    if not inner_scope_hint:
        return
    outer_scope_hint = outer_scope.expression.args.get("hint")
    if outer_scope_hint:
        for hint_expression in inner_scope_hint.expressions:
            outer_scope_hint.append("expressions", hint_expression)
    else:
        outer_scope.expression.set("hint", inner_scope_hint)


def _pop_cte(inner_scope: Scope) -> None:
    """
    Remove CTE from the AST.

    The whole WITH clause is dropped when this CTE was its only entry.

    Args:
        inner_scope (sqlglot.optimizer.scope.Scope)
    """
    cte = inner_scope.expression.parent
    with_ = cte.parent
    if len(with_.expressions) == 1:
        with_.pop()
    else:
        cte.pop()
def
merge_subqueries(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Inline derived tables, and CTEs that are selected from only once, into their outer queries.

    CTEs are merged first (`merge_ctes`), then derived tables (`merge_derived_tables`).

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression).sql()
        'SELECT x.a FROM x CROSS JOIN y'

    With `leave_tables_isolated=True`, an inner query is not merged into an outer query that
    already selects from more than one source, so tables stay in separate queries:
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
        >>> merge_subqueries(expression, leave_tables_isolated=True).sql()
        'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'

    Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html

    Args:
        expression (sqlglot.Expression): expression to optimize
        leave_tables_isolated (bool): if True, skip any merge into an outer query that
            selects from more than one source.
    Returns:
        sqlglot.Expression: optimized expression
    """
    ctes_merged = merge_ctes(expression, leave_tables_isolated)
    return merge_derived_tables(ctes_merged, leave_tables_isolated)
Rewrite sqlglot AST to merge derived tables into the outer query.
This also merges CTEs if they are selected from only once.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y") >>> merge_subqueries(expression).sql() 'SELECT x.a FROM x CROSS JOIN y'
If `leave_tables_isolated` is True, this will not merge inner queries into outer queries if it would result in multiple table selects in a single query:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y") >>> merge_subqueries(expression, leave_tables_isolated=True).sql() 'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
Arguments:
- expression (sqlglot.Expression): expression to optimize
- leave_tables_isolated (bool): if True, do not merge an inner query into an outer query that selects from more than one source.
Returns:
sqlglot.Expression: optimized expression
UNMERGABLE_ARGS =
{'operation_modifiers', 'into', 'options', 'distribute', 'connect', 'locks', 'group', 'having', 'cluster', 'distinct', 'limit', 'match', 'format', 'qualify', 'with', 'windows', 'pivots', 'sample', 'settings', 'laterals', 'sort', 'prewhere', 'kind', 'offset'}
SAFE_TO_REPLACE_UNWRAPPED =
(<class 'sqlglot.expressions.Column'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.Func'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.Paren'>)
def
merge_ctes(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge CTEs that are selected from exactly once into the query that selects from them.

    Args:
        expression: expression to optimize; it is modified in place.
        leave_tables_isolated: forwarded to `_mergeable`; when True, nothing is merged into
            an outer query that selects from more than one source.

    Returns:
        The same `expression` object, with merged CTEs removed from their WITH clause.
    """
    scopes = traverse_scope(expression)

    # All places where we select from CTEs.
    # We key on the CTE scope so we can detect CTEs that are selected from multiple times.
    cte_selections = defaultdict(list)
    for outer_scope in scopes:
        for table, inner_scope in outer_scope.selected_sources.values():
            if isinstance(inner_scope, Scope) and inner_scope.is_cte:
                cte_selections[id(inner_scope)].append((outer_scope, inner_scope, table))

    # Only CTEs referenced exactly once can be inlined without duplicating their body.
    singular_cte_selections = [
        selections[0] for selections in cte_selections.values() if len(selections) == 1
    ]
    for outer_scope, inner_scope, table in singular_cte_selections:
        from_or_join = table.find_ancestor(exp.From, exp.Join)
        if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
            alias = table.alias_or_name
            _rename_inner_sources(outer_scope, inner_scope, alias)
            _merge_from(outer_scope, inner_scope, table, alias)
            _merge_expressions(outer_scope, inner_scope, alias)
            _merge_order(outer_scope, inner_scope)
            _merge_joins(outer_scope, inner_scope, from_or_join)
            _merge_where(outer_scope, inner_scope, from_or_join)
            _merge_hints(outer_scope, inner_scope)
            _pop_cte(inner_scope)
            outer_scope.clear_cache()
    return expression
def
merge_derived_tables(expression: ~E, leave_tables_isolated: bool = False) -> ~E:
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
    """
    Merge derived tables (subqueries used as sources) into their enclosing queries.

    Every scope of `expression` is visited; each derived table that `_mergeable` accepts is
    inlined into the outer query, after which the outer scope's caches are cleared.

    Args:
        expression: expression to optimize; it is modified in place.
        leave_tables_isolated: forwarded to `_mergeable`; when True, nothing is merged into
            an outer query that selects from more than one source.

    Returns:
        The same `expression` object.
    """
    for outer_scope in traverse_scope(expression):
        for derived_table in outer_scope.derived_tables:
            clause = derived_table.find_ancestor(exp.From, exp.Join)
            name = derived_table.alias_or_name
            source_scope = outer_scope.sources[name]
            if not _mergeable(outer_scope, source_scope, leave_tables_isolated, clause):
                continue

            _rename_inner_sources(outer_scope, source_scope, name)
            _merge_from(outer_scope, source_scope, derived_table, name)
            _merge_expressions(outer_scope, source_scope, name)
            _merge_order(outer_scope, source_scope)
            _merge_joins(outer_scope, source_scope, clause)
            _merge_where(outer_scope, source_scope, clause)
            _merge_hints(outer_scope, source_scope)
            outer_scope.clear_cache()

    return expression