Edit on GitHub

sqlglot.optimizer.eliminate_joins

  1from sqlglot import expressions as exp
  2from sqlglot.optimizer.normalize import normalized
  3from sqlglot.optimizer.scope import Scope, traverse_scope
  4
  5
  6def eliminate_joins(expression):
  7    """
  8    Remove unused joins from an expression.
  9
 10    This only removes joins when we know that the join condition doesn't produce duplicate rows.
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
 15        >>> expression = sqlglot.parse_one(sql)
 16        >>> eliminate_joins(expression).sql()
 17        'SELECT x.a FROM x'
 18
 19    Args:
 20        expression (sqlglot.Expression): expression to optimize
 21    Returns:
 22        sqlglot.Expression: optimized expression
 23    """
 24    for scope in traverse_scope(expression):
 25        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
 26        # It's probably possible to infer this from the outputs of derived tables.
 27        # But for now, let's just skip this rule.
 28        if scope.unqualified_columns:
 29            continue
 30
 31        joins = scope.expression.args.get("joins", [])
 32
 33        # Reverse the joins so we can remove chains of unused joins
 34        for join in reversed(joins):
 35            if join.is_semi_or_anti_join:
 36                continue
 37
 38            alias = join.alias_or_name
 39            if _should_eliminate_join(scope, join, alias):
 40                join.pop()
 41                scope.remove_source(alias)
 42    return expression
 43
 44
 45def _should_eliminate_join(scope, join, alias):
 46    inner_source = scope.sources.get(alias)
 47    return (
 48        isinstance(inner_source, Scope)
 49        and not _join_is_used(scope, join, alias)
 50        and (
 51            (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join))
 52            or (not join.args.get("on") and _has_single_output_row(inner_source))
 53        )
 54    )
 55
 56
 57def _join_is_used(scope, join, alias):
 58    # We need to find all columns that reference this join.
 59    # But columns in the ON clause shouldn't count.
 60    on = join.args.get("on")
 61    if on:
 62        on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
 63    else:
 64        on_clause_columns = set()
 65    return any(
 66        column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
 67    )
 68
 69
 70def _is_joined_on_all_unique_outputs(scope, join):
 71    unique_outputs = _unique_outputs(scope)
 72    if not unique_outputs:
 73        return False
 74
 75    _, join_keys, _ = join_condition(join)
 76    remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
 77    return not remaining_unique_outputs
 78
 79
 80def _unique_outputs(scope):
 81    """Determine output columns of `scope` that must have a unique combination per row"""
 82    if scope.expression.args.get("distinct"):
 83        return set(scope.expression.named_selects)
 84
 85    group = scope.expression.args.get("group")
 86    if group:
 87        grouped_expressions = set(group.expressions)
 88        grouped_outputs = set()
 89
 90        unique_outputs = set()
 91        for select in scope.expression.selects:
 92            output = select.unalias()
 93            if output in grouped_expressions:
 94                grouped_outputs.add(output)
 95                unique_outputs.add(select.alias_or_name)
 96
 97        # All the grouped expressions must be in the output
 98        if not grouped_expressions.difference(grouped_outputs):
 99            return unique_outputs
100        else:
101            return set()
102
103    if _has_single_output_row(scope):
104        return set(scope.expression.named_selects)
105
106    return set()
107
108
def _has_single_output_row(scope):
    """
    Tell whether the query of `scope` is treated as producing a single row.

    That is the case for a SELECT which either
      * projects only aggregate functions and has no GROUP BY,
      * is limited with `LIMIT 1`, or
      * has no FROM clause.

    Returns:
        bool: True if one of the conditions above holds, False otherwise
        (including when the scope's expression is not a SELECT).
    """
    query = scope.expression
    if not isinstance(query, exp.Select):
        return False

    # Aggregates collapse the input to one row only when nothing is grouped;
    # with GROUP BY there is one row per group, so the shortcut must not apply.
    only_aggregates = all(isinstance(e.unalias(), exp.AggFunc) for e in query.selects)
    if only_aggregates and not query.args.get("group"):
        return True

    return bool(_is_limit_1(scope)) or not query.args.get("from")
115
116
def _is_limit_1(scope):
    """
    Tell whether the query of `scope` carries a `LIMIT` whose value is the literal text "1".

    Returns:
        bool: always a real boolean; False when there is no LIMIT at all.
    """
    limit = scope.expression.args.get("limit")
    if not limit:
        # Return False rather than leaking the missing (None) limit to callers.
        return False
    return bool(limit.expression.this == "1")
120
121
def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equality predicates that reference the joined source on exactly one side are
    pulled out as key pairs and replaced by TRUE in a copy of the ON clause;
    whatever is left of that copy is returned as the remaining predicate.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source key, join key, remaining predicate). The two lists
            are index-aligned: `source_key[i]` is the operand compared against
            `join_key[i]`.
    """
    name = join.alias_or_name
    # Work on a copy (TRUE when there is no ON clause): the `replace` calls
    # below edit this tree rather than the join's own ON expression.
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def extract_condition(condition):
        # Split the comparison into its two operands and collect the table
        # names referenced by each side.
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        # Only take the pair when the joined source appears on exactly one
        # side; the consumed comparison is replaced by TRUE in `on`.
        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        # NOTE(review): wrapping a non-AND predicate in `<on> AND TRUE` presumably
        # gives a lone equality a parent so `replace` takes effect inside `on` - confirm.
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)

        for condition in on.flatten():
            if isinstance(condition, exp.EQ):
                extract_condition(condition)
    elif normalized(on, dnf=True):
        # DNF: only equalities that occur in every flattened part can act as keys.
        conditions = None

        for condition in on.flatten():
            parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
            if conditions is None:
                # The first part seeds the candidate list.
                conditions = parts
            else:
                # Keep candidates that compare equal to an equality of this part.
                # Both the new node and the matching earlier nodes are retained,
                # so every occurrence is later passed to `extract_condition`.
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions:
            extract_condition(condition)

    return source_key, join_key, on
def eliminate_joins(expression):
 7def eliminate_joins(expression):
 8    """
 9    Remove unused joins from an expression.
10
11    This only removes joins when we know that the join condition doesn't produce duplicate rows.
12
13    Example:
14        >>> import sqlglot
15        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
16        >>> expression = sqlglot.parse_one(sql)
17        >>> eliminate_joins(expression).sql()
18        'SELECT x.a FROM x'
19
20    Args:
21        expression (sqlglot.Expression): expression to optimize
22    Returns:
23        sqlglot.Expression: optimized expression
24    """
25    for scope in traverse_scope(expression):
26        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
27        # It's probably possible to infer this from the outputs of derived tables.
28        # But for now, let's just skip this rule.
29        if scope.unqualified_columns:
30            continue
31
32        joins = scope.expression.args.get("joins", [])
33
34        # Reverse the joins so we can remove chains of unused joins
35        for join in reversed(joins):
36            if join.is_semi_or_anti_join:
37                continue
38
39            alias = join.alias_or_name
40            if _should_eliminate_join(scope, join, alias):
41                join.pop()
42                scope.remove_source(alias)
43    return expression

Remove unused joins from an expression.

This only removes joins when we know that the join condition doesn't produce duplicate rows.

Example:
>>> import sqlglot
>>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
>>> expression = sqlglot.parse_one(sql)
>>> eliminate_joins(expression).sql()
'SELECT x.a FROM x'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
Returns:

sqlglot.Expression: optimized expression

def join_condition(join):
def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equality predicates that reference the joined source on exactly one side are
    pulled out as key pairs and replaced by TRUE in a copy of the ON clause;
    whatever is left of that copy is returned as the remaining predicate.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source key, join key, remaining predicate). The two lists
            are index-aligned: `source_key[i]` is the operand compared against
            `join_key[i]`.
    """
    name = join.alias_or_name
    # Work on a copy (TRUE when there is no ON clause): the `replace` calls
    # below edit this tree rather than the join's own ON expression.
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def extract_condition(condition):
        # Split the comparison into its two operands and collect the table
        # names referenced by each side.
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        # Only take the pair when the joined source appears on exactly one
        # side; the consumed comparison is replaced by TRUE in `on`.
        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        # NOTE(review): wrapping a non-AND predicate in `<on> AND TRUE` presumably
        # gives a lone equality a parent so `replace` takes effect inside `on` - confirm.
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)

        for condition in on.flatten():
            if isinstance(condition, exp.EQ):
                extract_condition(condition)
    elif normalized(on, dnf=True):
        # DNF: only equalities that occur in every flattened part can act as keys.
        conditions = None

        for condition in on.flatten():
            parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
            if conditions is None:
                # The first part seeds the candidate list.
                conditions = parts
            else:
                # Keep candidates that compare equal to an equality of this part.
                # Both the new node and the matching earlier nodes are retained,
                # so every occurrence is later passed to `extract_condition`.
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions:
            extract_condition(condition)

    return source_key, join_key, on

Extract the join condition from a join expression.

Arguments:
  • join (exp.Join)
Returns:

tuple[list[exp.Expression], list[exp.Expression], exp.Expression]: Tuple of (source key, join key, remaining predicate); the two key lists are index-aligned operand expressions.