Edit on GitHub

sqlglot.optimizer.optimize_joins

 1from __future__ import annotations
 2
 3import typing as t
 4
 5from sqlglot import exp
 6from sqlglot.helper import tsort
 7
# Join args inspected by `normalize`: a join with none of these set is marked as a CROSS join.
JOIN_ATTRS = ("on", "side", "kind", "using", "method")
 9
10
def optimize_joins(expression):
    """
    Push predicates onto cross joins where possible, then reorder and normalize the joins.

    For every SELECT, joins are split into those whose ON condition mentions other
    tables and those whose condition mentions none (the cross join candidates).
    When a dependent join's ON condition is a connector that references at least
    two other tables, each predicate mentioning a cross-joined table is swapped
    for TRUE and attached to that cross join instead. The whole expression is
    then passed through `reorder_joins` and `normalize`.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
    """

    for select in expression.find_all(exp.Select):
        dependents = {}  # table name -> joins whose ON condition mentions that table
        unconditioned = []  # (name, join) pairs whose ON condition mentions no other table

        for join in select.args.get("joins", []):
            referenced = other_table_names(join)

            if not referenced:
                unconditioned.append((join.alias_or_name, join))
                continue

            for table in referenced:
                dependents.setdefault(table, []).append(join)

        for name, cross_join in unconditioned:
            for dependent in dependents.get(name, []):
                condition = dependent.args["on"]

                # Only a connector that spans two or more other tables can give up a predicate
                if not isinstance(condition, exp.Connector):
                    continue
                if len(other_table_names(dependent)) < 2:
                    continue

                connector = type(condition)
                for predicate in condition.flatten():
                    if name not in exp.column_table_names(predicate):
                        continue

                    # Leave TRUE behind in the dependent join and move the predicate over
                    predicate.replace(exp.true())
                    combined = exp._combine(
                        [cross_join.args.get("on"), predicate], connector, copy=False
                    )
                    cross_join.on(combined, append=False, copy=False)

    return normalize(reorder_joins(expression))
54
55
def reorder_joins(expression):
    """
    Sort the joins next to every FROM clause topologically, so that a join comes
    after the tables its ON condition references.
    """
    for from_ in expression.find_all(exp.From):
        owner = from_.parent

        joins_by_name = {}
        for join in owner.args.get("joins", []):
            joins_by_name[join.alias_or_name] = join

        # Dependency graph: join name -> names of the other tables its ON condition mentions
        dependencies = {}
        for name, join in joins_by_name.items():
            dependencies[name] = other_table_names(join)

        ordered = []
        for name in tsort(dependencies):
            # The sort may also yield the FROM table and names that are not joins here
            if name == from_.alias_or_name or name not in joins_by_name:
                continue
            ordered.append(joins_by_name[name])

        owner.set("joins", ordered)
    return expression
69
70
def normalize(expression):
    """
    Normalize every join in the expression.

    - A join with none of the JOIN_ATTRS set is marked as a CROSS join.
    - CROSS joins have their ON condition removed.
    - For all other joins, the optional INNER and OUTER keywords are removed.
      Any other kind (e.g. SEMI, ANTI) changes the meaning of the join and is
      therefore left untouched.
    - A non-CROSS join with neither an ON nor a USING clause gets `ON TRUE`.

    Returns the same expression, modified in place.
    """
    for join in expression.find_all(exp.Join):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        if join.kind == "CROSS":
            join.set("on", None)
        else:
            # Only INNER / OUTER are optional; clearing other kinds would change the query
            if join.kind in ("INNER", "OUTER"):
                join.set("kind", None)

            if not join.args.get("on") and not join.args.get("using"):
                join.set("on", exp.true())
    return expression
87
88
def other_table_names(join: exp.Join) -> t.Set[str]:
    """
    Return the table names referenced by the join's ON condition, as computed by
    `exp.column_table_names` with the join's own alias or name passed as the name
    to leave out. A join without an ON condition yields an empty set.
    """
    condition = join.args.get("on")
    if not condition:
        return set()
    return exp.column_table_names(condition, join.alias_or_name)
JOIN_ATTRS = ('on', 'side', 'kind', 'using', 'method')
def optimize_joins(expression):
def optimize_joins(expression):
    """
    Push predicates onto cross joins where possible, then reorder and normalize the joins.

    For every SELECT, joins are split into those whose ON condition mentions other
    tables and those whose condition mentions none (the cross join candidates).
    When a dependent join's ON condition is a connector that references at least
    two other tables, each predicate mentioning a cross-joined table is swapped
    for TRUE and attached to that cross join instead. The whole expression is
    then passed through `reorder_joins` and `normalize`.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
    """

    for select in expression.find_all(exp.Select):
        dependents = {}  # table name -> joins whose ON condition mentions that table
        unconditioned = []  # (name, join) pairs whose ON condition mentions no other table

        for join in select.args.get("joins", []):
            referenced = other_table_names(join)

            if not referenced:
                unconditioned.append((join.alias_or_name, join))
                continue

            for table in referenced:
                dependents.setdefault(table, []).append(join)

        for name, cross_join in unconditioned:
            for dependent in dependents.get(name, []):
                condition = dependent.args["on"]

                # Only a connector that spans two or more other tables can give up a predicate
                if not isinstance(condition, exp.Connector):
                    continue
                if len(other_table_names(dependent)) < 2:
                    continue

                connector = type(condition)
                for predicate in condition.flatten():
                    if name not in exp.column_table_names(predicate):
                        continue

                    # Leave TRUE behind in the dependent join and move the predicate over
                    predicate.replace(exp.true())
                    combined = exp._combine(
                        [cross_join.args.get("on"), predicate], connector, copy=False
                    )
                    cross_join.on(combined, append=False, copy=False)

    return normalize(reorder_joins(expression))

Removes cross joins where possible and reorders joins based on predicate dependencies.

Example:
>>> from sqlglot import parse_one
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
def reorder_joins(expression):
def reorder_joins(expression):
    """
    Sort the joins next to every FROM clause topologically, so that a join comes
    after the tables its ON condition references.
    """
    for from_ in expression.find_all(exp.From):
        owner = from_.parent

        joins_by_name = {}
        for join in owner.args.get("joins", []):
            joins_by_name[join.alias_or_name] = join

        # Dependency graph: join name -> names of the other tables its ON condition mentions
        dependencies = {}
        for name, join in joins_by_name.items():
            dependencies[name] = other_table_names(join)

        ordered = []
        for name in tsort(dependencies):
            # The sort may also yield the FROM table and names that are not joins here
            if name == from_.alias_or_name or name not in joins_by_name:
                continue
            ordered.append(joins_by_name[name])

        owner.set("joins", ordered)
    return expression

Reorder joins by topological sort order based on predicate references.

def normalize(expression):
def normalize(expression):
    """
    Normalize every join in the expression.

    - A join with none of the JOIN_ATTRS set is marked as a CROSS join.
    - CROSS joins have their ON condition removed.
    - For all other joins, the optional INNER and OUTER keywords are removed.
      Any other kind (e.g. SEMI, ANTI) changes the meaning of the join and is
      therefore left untouched.
    - A non-CROSS join with neither an ON nor a USING clause gets `ON TRUE`.

    Returns the same expression, modified in place.
    """
    for join in expression.find_all(exp.Join):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        if join.kind == "CROSS":
            join.set("on", None)
        else:
            # Only INNER / OUTER are optional; clearing other kinds would change the query
            if join.kind in ("INNER", "OUTER"):
                join.set("kind", None)

            if not join.args.get("on") and not join.args.get("using"):
                join.set("on", exp.true())
    return expression

Remove INNER and OUTER from joins as they are optional.

def other_table_names(join: sqlglot.expressions.Join) -> Set[str]:
def other_table_names(join: exp.Join) -> t.Set[str]:
    """
    Return the table names referenced by the join's ON condition, as computed by
    `exp.column_table_names` with the join's own alias or name passed as the name
    to leave out. A join without an ON condition yields an empty set.
    """
    condition = join.args.get("on")
    if not condition:
        return set()
    return exp.column_table_names(condition, join.alias_or_name)