sqlglot.optimizer.optimize_joins
1from __future__ import annotations 2 3import typing as t 4 5from sqlglot import exp 6from sqlglot.helper import tsort 7 8JOIN_ATTRS = ("on", "side", "kind", "using", "method") 9 10 11def optimize_joins(expression): 12 """ 13 Removes cross joins if possible and reorder joins based on predicate dependencies. 14 15 Example: 16 >>> from sqlglot import parse_one 17 >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql() 18 'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a' 19 """ 20 21 for select in expression.find_all(exp.Select): 22 references = {} 23 cross_joins = [] 24 25 for join in select.args.get("joins", []): 26 tables = other_table_names(join) 27 28 if tables: 29 for table in tables: 30 references[table] = references.get(table, []) + [join] 31 else: 32 cross_joins.append((join.alias_or_name, join)) 33 34 for name, join in cross_joins: 35 for dep in references.get(name, []): 36 on = dep.args["on"] 37 38 if isinstance(on, exp.Connector): 39 if len(other_table_names(dep)) < 2: 40 continue 41 42 operator = type(on) 43 for predicate in on.flatten(): 44 if name in exp.column_table_names(predicate): 45 predicate.replace(exp.true()) 46 predicate = exp._combine( 47 [join.args.get("on"), predicate], operator, copy=False 48 ) 49 join.on(predicate, append=False, copy=False) 50 51 expression = reorder_joins(expression) 52 expression = normalize(expression) 53 return expression 54 55 56def reorder_joins(expression): 57 """ 58 Reorder joins by topological sort order based on predicate references. 59 """ 60 for from_ in expression.find_all(exp.From): 61 parent = from_.parent 62 joins = {join.alias_or_name: join for join in parent.args.get("joins", [])} 63 dag = {name: other_table_names(join) for name, join in joins.items()} 64 parent.set( 65 "joins", 66 [joins[name] for name in tsort(dag) if name != from_.alias_or_name and name in joins], 67 ) 68 return expression 69 70 71def normalize(expression): 72 """ 73 Remove INNER and OUTER from joins as they are optional. 74 """ 75 for join in expression.find_all(exp.Join): 76 if not any(join.args.get(k) for k in JOIN_ATTRS): 77 join.set("kind", "CROSS") 78 79 if join.kind == "CROSS": 80 join.set("on", None) 81 else: 82 join.set("kind", None) 83 84 if not join.args.get("on") and not join.args.get("using"): 85 join.set("on", exp.true()) 86 return expression 87 88 89def other_table_names(join: exp.Join) -> t.Set[str]: 90 on = join.args.get("on") 91 return exp.column_table_names(on, join.alias_or_name) if on else set()
JOIN_ATTRS =
('on', 'side', 'kind', 'using', 'method')
def
optimize_joins(expression):
12def optimize_joins(expression): 13 """ 14 Removes cross joins if possible and reorder joins based on predicate dependencies. 15 16 Example: 17 >>> from sqlglot import parse_one 18 >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql() 19 'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a' 20 """ 21 22 for select in expression.find_all(exp.Select): 23 references = {} 24 cross_joins = [] 25 26 for join in select.args.get("joins", []): 27 tables = other_table_names(join) 28 29 if tables: 30 for table in tables: 31 references[table] = references.get(table, []) + [join] 32 else: 33 cross_joins.append((join.alias_or_name, join)) 34 35 for name, join in cross_joins: 36 for dep in references.get(name, []): 37 on = dep.args["on"] 38 39 if isinstance(on, exp.Connector): 40 if len(other_table_names(dep)) < 2: 41 continue 42 43 operator = type(on) 44 for predicate in on.flatten(): 45 if name in exp.column_table_names(predicate): 46 predicate.replace(exp.true()) 47 predicate = exp._combine( 48 [join.args.get("on"), predicate], operator, copy=False 49 ) 50 join.on(predicate, append=False, copy=False) 51 52 expression = reorder_joins(expression) 53 expression = normalize(expression) 54 return expression
Removes cross joins if possible and reorder joins based on predicate dependencies.
Example:
>>> from sqlglot import parse_one >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql() 'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
def
reorder_joins(expression):
57def reorder_joins(expression): 58 """ 59 Reorder joins by topological sort order based on predicate references. 60 """ 61 for from_ in expression.find_all(exp.From): 62 parent = from_.parent 63 joins = {join.alias_or_name: join for join in parent.args.get("joins", [])} 64 dag = {name: other_table_names(join) for name, join in joins.items()} 65 parent.set( 66 "joins", 67 [joins[name] for name in tsort(dag) if name != from_.alias_or_name and name in joins], 68 ) 69 return expression
Reorder joins by topological sort order based on predicate references.
def
normalize(expression):
72def normalize(expression): 73 """ 74 Remove INNER and OUTER from joins as they are optional. 75 """ 76 for join in expression.find_all(exp.Join): 77 if not any(join.args.get(k) for k in JOIN_ATTRS): 78 join.set("kind", "CROSS") 79 80 if join.kind == "CROSS": 81 join.set("on", None) 82 else: 83 join.set("kind", None) 84 85 if not join.args.get("on") and not join.args.get("using"): 86 join.set("on", exp.true()) 87 return expression
Remove INNER and OUTER from joins as they are optional.