Edit on GitHub

sqlglot.optimizer.optimize_joins

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import exp
  6from sqlglot.helper import tsort
  7
  8JOIN_ATTRS = ("on", "side", "kind", "using", "method")
  9
 10
 11def optimize_joins(expression):
 12    """
 13    Removes cross joins if possible and reorder joins based on predicate dependencies.
 14
 15    Example:
 16        >>> from sqlglot import parse_one
 17        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
 18        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
 19    """
 20
 21    for select in expression.find_all(exp.Select):
 22        joins = select.args.get("joins", [])
 23
 24        if not _is_reorderable(joins):
 25            continue
 26
 27        references = {}
 28        cross_joins = []
 29
 30        for join in joins:
 31            tables = other_table_names(join)
 32
 33            if tables:
 34                for table in tables:
 35                    references[table] = references.get(table, []) + [join]
 36            else:
 37                cross_joins.append((join.alias_or_name, join))
 38
 39        for name, join in cross_joins:
 40            for dep in references.get(name, []):
 41                on = dep.args["on"]
 42
 43                if isinstance(on, exp.Connector):
 44                    if len(other_table_names(dep)) < 2:
 45                        continue
 46
 47                    operator = type(on)
 48                    for predicate in on.flatten():
 49                        if name in exp.column_table_names(predicate):
 50                            predicate.replace(exp.true())
 51                            predicate = exp._combine(
 52                                [join.args.get("on"), predicate], operator, copy=False
 53                            )
 54                            join.on(predicate, append=False, copy=False)
 55
 56    expression = reorder_joins(expression)
 57    expression = normalize(expression)
 58    return expression
 59
 60
 61def reorder_joins(expression):
 62    """
 63    Reorder joins by topological sort order based on predicate references.
 64    """
 65    for from_ in expression.find_all(exp.From):
 66        parent = from_.parent
 67        joins = parent.args.get("joins", [])
 68
 69        if not _is_reorderable(joins):
 70            continue
 71
 72        joins_by_name = {join.alias_or_name: join for join in joins}
 73        dag = {name: other_table_names(join) for name, join in joins_by_name.items()}
 74        parent.set(
 75            "joins",
 76            [
 77                joins_by_name[name]
 78                for name in tsort(dag)
 79                if name != from_.alias_or_name and name in joins_by_name
 80            ],
 81        )
 82    return expression
 83
 84
 85def normalize(expression):
 86    """
 87    Remove INNER and OUTER from joins as they are optional.
 88    """
 89    for join in expression.find_all(exp.Join):
 90        if not any(join.args.get(k) for k in JOIN_ATTRS):
 91            join.set("kind", "CROSS")
 92
 93        if join.kind == "CROSS":
 94            join.set("on", None)
 95        else:
 96            if join.kind in ("INNER", "OUTER"):
 97                join.set("kind", None)
 98
 99            if not join.args.get("on") and not join.args.get("using"):
100                join.set("on", exp.true())
101    return expression
102
103
def other_table_names(join: exp.Join) -> t.Set[str]:
    """
    Return the table names referenced by columns in the join's ON clause,
    excluding the joined table itself; empty when there is no ON clause.
    """
    condition = join.args.get("on")
    if not condition:
        return set()
    return exp.column_table_names(condition, join.alias_or_name)
107
108
def _is_reorderable(joins: t.List[exp.Join]) -> bool:
    """
    Tell whether the given joins may be reordered without changing query semantics.

    A join with a side (LEFT, RIGHT, FULL) cannot be reordered easily because the
    order affects which rows end up in the result, so a single sided join makes
    the whole list non-reorderable.

    Example:
        >>> from sqlglot import parse_one, exp
        >>> from sqlglot.optimizer.optimize_joins import _is_reorderable
        >>> ast = parse_one("SELECT * FROM x JOIN y ON x.id = y.id JOIN z ON y.id = z.id")
        >>> _is_reorderable(ast.find(exp.Select).args.get("joins", []))
        True
        >>> ast = parse_one("SELECT * FROM x LEFT JOIN y ON x.id = y.id JOIN z ON y.id = z.id")
        >>> _is_reorderable(ast.find(exp.Select).args.get("joins", []))
        False
    """
    for join in joins:
        if join.side:
            return False
    return True
# Join arguments inspected by normalize(); a join that sets none of them is marked as CROSS.
JOIN_ATTRS = ('on', 'side', 'kind', 'using', 'method')
def optimize_joins(expression):
def optimize_joins(expression):
    """
    Removes cross joins where possible and reorders joins based on predicate dependencies.

    For every reorderable SELECT, predicates in another join's ON clause that mention
    a join without table references of its own are moved onto that join (TRUE is left
    in their place). Afterwards the joins are reordered and normalized.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
    """
    for select in expression.find_all(exp.Select):
        joins = select.args.get("joins", [])
        if not _is_reorderable(joins):
            continue

        # Table name -> joins whose ON clause references that table.
        dependents = {}
        # (name, join) pairs for joins whose ON clause references no other table.
        standalone = []

        for join in joins:
            referenced = other_table_names(join)
            if not referenced:
                standalone.append((join.alias_or_name, join))
                continue
            for table in referenced:
                dependents.setdefault(table, []).append(join)

        for name, cross in standalone:
            for dependent in dependents.get(name, []):
                condition = dependent.args["on"]

                # Only AND/OR-style connectors can have individual predicates pulled out.
                if not isinstance(condition, exp.Connector):
                    continue
                # The dependent join must still reference some other table afterwards.
                if len(other_table_names(dependent)) < 2:
                    continue

                connector = type(condition)
                for term in condition.flatten():
                    if name not in exp.column_table_names(term):
                        continue
                    # Leave TRUE behind and attach the predicate to the cross join instead.
                    term.replace(exp.true())
                    merged = exp._combine([cross.args.get("on"), term], connector, copy=False)
                    cross.on(merged, append=False, copy=False)

    return normalize(reorder_joins(expression))

Removes cross joins where possible and reorders joins based on predicate dependencies.

Example:
>>> from sqlglot import parse_one
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
def reorder_joins(expression):
def reorder_joins(expression):
    """
    Sort the joins of every FROM's parent topologically, using the tables each
    join's ON clause references as its dependencies.
    """
    for from_ in expression.find_all(exp.From):
        parent = from_.parent
        joins = parent.args.get("joins", [])
        if not _is_reorderable(joins):
            continue

        by_name = {}
        for join in joins:
            by_name[join.alias_or_name] = join

        # Dependency graph: join name -> names of the other tables its ON clause uses.
        dag = {}
        for name, join in by_name.items():
            dag[name] = other_table_names(join)

        ordered = []
        for name in tsort(dag):
            # The sort also yields referenced tables that are not joins (e.g. the FROM table).
            if name != from_.alias_or_name and name in by_name:
                ordered.append(by_name[name])

        parent.set("joins", ordered)
    return expression

Reorder joins by topological sort order based on predicate references.

def normalize(expression):
 86def normalize(expression):
 87    """
 88    Remove INNER and OUTER from joins as they are optional.
 89    """
 90    for join in expression.find_all(exp.Join):
 91        if not any(join.args.get(k) for k in JOIN_ATTRS):
 92            join.set("kind", "CROSS")
 93
 94        if join.kind == "CROSS":
 95            join.set("on", None)
 96        else:
 97            if join.kind in ("INNER", "OUTER"):
 98                join.set("kind", None)
 99
100            if not join.args.get("on") and not join.args.get("using"):
101                join.set("on", exp.true())
102    return expression

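Remove INNER and OUTER from joins, as they are optional. Joins that set none of the join attributes are also marked as CROSS, CROSS joins lose their ON clause, and any other join without ON or USING gets ON TRUE.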

def other_table_names(join: sqlglot.expressions.Join) -> Set[str]:
def other_table_names(join: exp.Join) -> t.Set[str]:
    """
    Return the table names referenced by columns in the join's ON clause,
    excluding the joined table itself; empty when there is no ON clause.
    """
    condition = join.args.get("on")
    if not condition:
        return set()
    return exp.column_table_names(condition, join.alias_or_name)