sqlglot.optimizer.optimize_joins
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.helper import tsort

# Join arguments; a join with none of these set is normalized into a CROSS join.
JOIN_ATTRS = ("on", "side", "kind", "using", "method")


def optimize_joins(expression):
    """
    Removes cross joins where possible and reorders joins based on predicate dependencies.

    A predicate is only moved out of another join's condition when that condition is a
    conjunction (AND): pulling a disjunct out of an OR would change the query's result.
    Joins that already carry a USING clause or a join method (e.g. NATURAL) are never
    treated as cross joins, because attaching an ON clause to them yields invalid SQL.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the expression tree to optimize; it is modified in place.

    Returns:
        The optimized expression.
    """
    for select in expression.find_all(exp.Select):
        joins = select.args.get("joins", [])

        if not _is_reorderable(joins):
            continue

        # table name -> joins whose ON clause references that table
        references: t.Dict[str, t.List[exp.Join]] = {}
        # (name, join) pairs for joins that have no join condition of any kind
        cross_joins = []

        for join in joins:
            tables = other_table_names(join)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            elif not join.args.get("using") and not join.args.get("method"):
                # No ON, no USING, no NATURAL/ASOF/...: a genuine cross join.
                cross_joins.append((join.alias_or_name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                # Only the conjuncts of an AND can be moved independently.
                if not isinstance(on, exp.And):
                    continue

                # If dep only references `name`, reordering alone is enough.
                if len(other_table_names(dep)) < 2:
                    continue

                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        # Neutralize the conjunct in dep and AND it onto the cross join.
                        predicate.replace(exp.true())
                        predicate = exp._combine(
                            [join.args.get("on"), predicate], exp.And, copy=False
                        )
                        join.on(predicate, append=False, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression


def reorder_joins(expression):
    """
    Reorder joins by topological sort order based on predicate references.

    Each join is placed after the tables its ON clause references. A scope is left
    untouched when two joins (or a join and the FROM source) share the same name:
    the name-keyed sort cannot tell them apart and would otherwise drop joins.
    """
    for from_ in expression.find_all(exp.From):
        parent = from_.parent
        joins = parent.args.get("joins", [])

        if not _is_reorderable(joins):
            continue

        joins_by_name = {join.alias_or_name: join for join in joins}

        # Name collisions would make joins vanish from the rebuilt list below.
        if len(joins_by_name) != len(joins) or from_.alias_or_name in joins_by_name:
            continue

        dag = {name: other_table_names(join) for name, join in joins_by_name.items()}
        parent.set(
            "joins",
            [
                joins_by_name[name]
                for name in tsort(dag)
                # tsort also yields referenced tables that are not joins (e.g. FROM).
                if name != from_.alias_or_name and name in joins_by_name
            ],
        )
    return expression


def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    Also marks joins that have no join attributes as CROSS joins, drops the ON clause
    of CROSS joins, and gives other joins lacking ON/USING an explicit ``ON TRUE`` --
    except NATURAL and POSITIONAL joins, whose syntax does not allow a join condition.
    """
    for join in expression.find_all(exp.Join):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        if join.kind == "CROSS":
            join.set("on", None)
            continue

        if join.kind in ("INNER", "OUTER"):
            join.set("kind", None)

        takes_condition = join.text("method").upper() not in ("NATURAL", "POSITIONAL")
        if takes_condition and not join.args.get("on") and not join.args.get("using"):
            join.set("on", exp.true())
    return expression


def other_table_names(join: exp.Join) -> t.Set[str]:
    """Return the table names referenced by the join's ON clause, excluding the
    join's own name; an empty set when the join has no ON clause."""
    on = join.args.get("on")
    return exp.column_table_names(on, join.alias_or_name) if on else set()


def _is_reorderable(joins: t.List[exp.Join]) -> bool:
    """
    Checks if joins can be reordered without changing query semantics.

    Joins with a side (LEFT, RIGHT, FULL) cannot be reordered easily,
    the order affects which rows are included in the result.

    Example:
        >>> from sqlglot import parse_one, exp
        >>> from sqlglot.optimizer.optimize_joins import _is_reorderable
        >>> ast = parse_one("SELECT * FROM x JOIN y ON x.id = y.id JOIN z ON y.id = z.id")
        >>> _is_reorderable(ast.find(exp.Select).args.get("joins", []))
        True
        >>> ast = parse_one("SELECT * FROM x LEFT JOIN y ON x.id = y.id JOIN z ON y.id = z.id")
        >>> _is_reorderable(ast.find(exp.Select).args.get("joins", []))
        False
    """
    return not any(join.side for join in joins)
# Join arguments; a join with none of these set is normalized into a CROSS join.
JOIN_ATTRS = ("on", "side", "kind", "using", "method")
def optimize_joins(expression):
    """
    Removes cross joins where possible and reorders joins based on predicate dependencies.

    A predicate is only moved out of another join's condition when that condition is a
    conjunction (AND): pulling a disjunct out of an OR would change the query's result.
    Joins that already carry a USING clause or a join method (e.g. NATURAL) are never
    treated as cross joins, because attaching an ON clause to them yields invalid SQL.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the expression tree to optimize; it is modified in place.

    Returns:
        The optimized expression.
    """
    for select in expression.find_all(exp.Select):
        joins = select.args.get("joins", [])

        if not _is_reorderable(joins):
            continue

        # table name -> joins whose ON clause references that table
        references = {}
        # (name, join) pairs for joins that have no join condition of any kind
        cross_joins = []

        for join in joins:
            tables = other_table_names(join)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            elif not join.args.get("using") and not join.args.get("method"):
                # No ON, no USING, no NATURAL/ASOF/...: a genuine cross join.
                cross_joins.append((join.alias_or_name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                # Only the conjuncts of an AND can be moved independently.
                if not isinstance(on, exp.And):
                    continue

                # If dep only references `name`, reordering alone is enough.
                if len(other_table_names(dep)) < 2:
                    continue

                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        # Neutralize the conjunct in dep and AND it onto the cross join.
                        predicate.replace(exp.true())
                        predicate = exp._combine(
                            [join.args.get("on"), predicate], exp.And, copy=False
                        )
                        join.on(predicate, append=False, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression
Removes cross joins where possible and reorders joins based on predicate dependencies.
Example:
>>> from sqlglot import parse_one
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
def reorder_joins(expression):
    """
    Reorder joins by topological sort order based on predicate references.

    Each join is placed after the tables its ON clause references. A scope is left
    untouched when two joins (or a join and the FROM source) share the same name:
    the name-keyed sort cannot tell them apart and would otherwise drop joins.

    Args:
        expression: the expression tree whose joins are reordered in place.

    Returns:
        The same expression.
    """
    for from_ in expression.find_all(exp.From):
        parent = from_.parent
        joins = parent.args.get("joins", [])

        if not _is_reorderable(joins):
            continue

        joins_by_name = {join.alias_or_name: join for join in joins}

        # Name collisions would make joins vanish from the rebuilt list below.
        if len(joins_by_name) != len(joins) or from_.alias_or_name in joins_by_name:
            continue

        dag = {name: other_table_names(join) for name, join in joins_by_name.items()}
        parent.set(
            "joins",
            [
                joins_by_name[name]
                for name in tsort(dag)
                # tsort also yields referenced tables that are not joins (e.g. FROM).
                if name != from_.alias_or_name and name in joins_by_name
            ],
        )
    return expression
Reorder joins by topological sort order based on predicate references, so that each join comes after the tables its ON clause references.
def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    Also marks joins that have no join attributes as CROSS joins, drops the ON clause
    of CROSS joins, and gives other joins lacking ON/USING an explicit ``ON TRUE`` --
    except NATURAL and POSITIONAL joins, whose syntax does not allow a join condition.

    Args:
        expression: the expression tree whose joins are normalized in place.

    Returns:
        The same expression.
    """
    for join in expression.find_all(exp.Join):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        if join.kind == "CROSS":
            join.set("on", None)
            continue

        if join.kind in ("INNER", "OUTER"):
            join.set("kind", None)

        takes_condition = join.text("method").upper() not in ("NATURAL", "POSITIONAL")
        if takes_condition and not join.args.get("on") and not join.args.get("using"):
            join.set("on", exp.true())
    return expression
Remove the optional INNER and OUTER keywords from joins and normalize join conditions (for example, CROSS joins lose any ON clause).