Edit on GitHub

sqlglot.optimizer.eliminate_joins

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.optimizer.normalize import normalized
  7from sqlglot.optimizer.scope import Scope, traverse_scope
  8
  9if t.TYPE_CHECKING:
 10    from sqlglot._typing import E
 11
 12
 13def eliminate_joins(expression: E) -> E:
 14    """
 15    Remove unused joins from an expression.
 16
 17    This only removes joins when we know that the join condition doesn't produce duplicate rows.
 18
 19    Example:
 20        >>> import sqlglot
 21        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
 22        >>> expression = sqlglot.parse_one(sql)
 23        >>> eliminate_joins(expression).sql()
 24        'SELECT x.a FROM x'
 25
 26    Args:
 27        expression: expression to optimize
 28
 29    Returns:
 30        The optimized expression
 31    """
 32    for scope in traverse_scope(expression):
 33        joins = scope.expression.args.get("joins", [])
 34        if not joins:
 35            continue
 36
 37        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
 38        # It's probably possible to infer this from the outputs of derived tables.
 39        # But for now, let's just skip this rule.
 40        if scope.unqualified_columns:
 41            continue
 42
 43        # Reverse the joins so we can remove chains of unused joins
 44        for join in reversed(joins):
 45            if join.is_semi_or_anti_join:
 46                continue
 47
 48            alias = join.alias_or_name
 49            if _should_eliminate_join(scope, join, alias):
 50                join.pop()
 51                scope.remove_source(alias)
 52
 53    return expression
 54
 55
 56def _should_eliminate_join(scope, join, alias):
 57    inner_source = scope.sources.get(alias)
 58    return (
 59        isinstance(inner_source, Scope)
 60        and not _join_is_used(scope, join, alias)
 61        and (
 62            (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join))
 63            or (not join.args.get("on") and _has_single_output_row(inner_source))
 64        )
 65    )
 66
 67
 68def _join_is_used(scope, join, alias):
 69    # We need to find all columns that reference this join.
 70    # But columns in the ON clause shouldn't count.
 71    on = join.args.get("on")
 72    if on:
 73        on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
 74    else:
 75        on_clause_columns = set()
 76    return any(
 77        column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
 78    )
 79
 80
 81def _is_joined_on_all_unique_outputs(scope, join):
 82    unique_outputs = _unique_outputs(scope)
 83    if not unique_outputs:
 84        return False
 85
 86    _, join_keys, _ = join_condition(join)
 87    remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
 88    return not remaining_unique_outputs
 89
 90
 91def _unique_outputs(scope):
 92    """Determine output columns of `scope` that must have a unique combination per row"""
 93    if scope.expression.args.get("distinct"):
 94        return set(scope.expression.named_selects)
 95
 96    group = scope.expression.args.get("group")
 97    if group:
 98        grouped_expressions = set(group.expressions)
 99        grouped_outputs = set()
100
101        unique_outputs = set()
102        for select in scope.expression.selects:
103            output = select.unalias()
104            if output in grouped_expressions:
105                grouped_outputs.add(output)
106                unique_outputs.add(select.alias_or_name)
107
108        # All the grouped expressions must be in the output
109        if not grouped_expressions.difference(grouped_outputs):
110            return unique_outputs
111        else:
112            return set()
113
114    if _has_single_output_row(scope):
115        return set(scope.expression.named_selects)
116
117    return set()
118
119
def _has_single_output_row(scope):
    """
    Return True if `scope` is a SELECT known to produce at most one row.

    This is assumed when any of the following holds:

    * every projection is an aggregate function and there is no GROUP BY;
    * the query is limited to one row (`LIMIT 1`);
    * the query has no FROM clause.

    NOTE(review): LIMIT 1, WHERE or HAVING can make the result empty rather than
    exactly one row. Confirm this is acceptable for callers that drop predicate-less
    joins based on this check.
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return False

    # Aggregates collapse the input to a single row only without GROUP BY;
    # with GROUP BY they yield one row per group.
    all_aggregates = not select.args.get("group") and all(
        isinstance(e.unalias(), exp.AggFunc) for e in select.selects
    )
    return bool(all_aggregates or _is_limit_1(scope) or not select.args.get("from_"))
126
127
128def _is_limit_1(scope):
129    limit = scope.expression.args.get("limit")
130    return limit and limit.expression.this == "1"
131
132
def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equality predicates that compare the joined source (by alias/name) with the
    other sources are pulled out as key pairs. They are replaced with TRUE in a copy
    of the ON clause, and that copy is returned as the remaining predicate. For a
    DNF condition, only equalities shared by every disjunct are extracted.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source key, join key, remaining predicate)
    """
    name = join.alias_or_name
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def extract_condition(condition):
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)

        for condition in on.flatten():
            if isinstance(condition, exp.EQ):
                extract_condition(condition)
    elif normalized(on, dnf=True):
        conditions = None

        # Only a top-level OR yields disjuncts; anything else (e.g. a NOT wrapping
        # an OR) is treated as one opaque term so nothing is pulled out of it.
        disjuncts = on.flatten() if isinstance(on, exp.Or) else (on,)

        for condition in disjuncts:
            # A disjunct is either a conjunction or a single predicate. Flattening a
            # bare predicate would yield its operands rather than the predicate itself.
            terms = condition.flatten() if isinstance(condition, exp.And) else (condition,)
            parts = [part for part in terms if isinstance(part, exp.EQ)]
            if conditions is None:
                conditions = parts
            else:
                # Keep only equalities present in every disjunct so far. All copies
                # are kept so each one is replaced with TRUE in the remainder.
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions or []:
            extract_condition(condition)

    return source_key, join_key, on
def eliminate_joins(expression: ~E) -> ~E:
def eliminate_joins(expression: E) -> E:
    """
    Drop joins whose sources are never referenced, when doing so cannot change the result.

    A join is only dropped when it is provably unable to duplicate rows of the
    remaining sources (see `_should_eliminate_join`).

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression: expression to optimize

    Returns:
        The optimized expression
    """
    for current_scope in traverse_scope(expression):
        scope_joins = current_scope.expression.args.get("joins", [])

        # Scopes without joins have nothing to remove. Scopes with unqualified
        # columns are skipped as well: without a table qualifier we cannot tell
        # which source a column belongs to, so a join can't be proven unused.
        if not scope_joins or current_scope.unqualified_columns:
            continue

        # Walk from the last join to the first so that a chain of joins which only
        # feed each other's ON clauses can be peeled off one link at a time.
        for candidate in reversed(scope_joins):
            if candidate.is_semi_or_anti_join:
                continue

            source_name = candidate.alias_or_name
            if not _should_eliminate_join(current_scope, candidate, source_name):
                continue

            candidate.pop()
            current_scope.remove_source(source_name)

    return expression

Remove unused joins from an expression.

This only removes joins when we know that the join condition doesn't produce duplicate rows.

Example:
>>> import sqlglot
>>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
>>> expression = sqlglot.parse_one(sql)
>>> eliminate_joins(expression).sql()
'SELECT x.a FROM x'
Arguments:
  • expression: expression to optimize
Returns:

The optimized expression

def join_condition(join):
def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equality predicates that compare the joined source (by alias/name) with the
    other sources are pulled out as key pairs. They are replaced with TRUE in a copy
    of the ON clause, and that copy is returned as the remaining predicate. For a
    DNF condition, only equalities shared by every disjunct are extracted.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source key, join key, remaining predicate)
    """
    name = join.alias_or_name
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def extract_condition(condition):
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)

        for condition in on.flatten():
            if isinstance(condition, exp.EQ):
                extract_condition(condition)
    elif normalized(on, dnf=True):
        conditions = None

        # Only a top-level OR yields disjuncts; anything else (e.g. a NOT wrapping
        # an OR) is treated as one opaque term so nothing is pulled out of it.
        disjuncts = on.flatten() if isinstance(on, exp.Or) else (on,)

        for condition in disjuncts:
            # A disjunct is either a conjunction or a single predicate. Flattening a
            # bare predicate would yield its operands rather than the predicate itself.
            terms = condition.flatten() if isinstance(condition, exp.And) else (condition,)
            parts = [part for part in terms if isinstance(part, exp.EQ)]
            if conditions is None:
                conditions = parts
            else:
                # Keep only equalities present in every disjunct so far. All copies
                # are kept so each one is replaced with TRUE in the remainder.
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions or []:
            extract_condition(condition)

    return source_key, join_key, on

Extract the join condition from a join expression.

Arguments:
  • join (exp.Join)
Returns:

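tuple[list[exp.Expression], list[exp.Expression], exp.Expression]: Tuple of (source key, join key, remaining predicate). The key lists hold the operand expressions of the extracted equality predicates, not strings.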