sqlglot.optimizer.eliminate_joins
"""Optimizer rule that removes joins whose sources are never referenced."""

from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import Scope, traverse_scope

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def eliminate_joins(expression: E) -> E:
    """
    Remove unused joins from an expression.

    This only removes joins when we know that the join condition doesn't produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression: expression to optimize

    Returns:
        The optimized expression (the same object, modified in place)
    """
    for scope in traverse_scope(expression):
        joins = scope.expression.args.get("joins", [])
        if not joins:
            continue

        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
        # It's probably possible to infer this from the outputs of derived tables.
        # But for now, let's just skip this rule.
        if scope.unqualified_columns:
            continue

        # Reverse the joins so we can remove chains of unused joins
        for join in reversed(joins):
            if join.is_semi_or_anti_join:
                continue

            alias = join.alias_or_name
            if _should_eliminate_join(scope, join, alias):
                join.pop()
                scope.remove_source(alias)

    return expression


def _should_eliminate_join(scope, join, alias):
    """
    Decide whether `join` can be dropped from `scope`.

    A join is dropped only when its source is a nested scope (not a plain table),
    nothing outside its ON clause references it, and it cannot change the number
    of rows: either a LEFT join on all unique outputs of the source, or a join
    without an ON clause against a source that yields a single row.
    """
    inner_source = scope.sources.get(alias)
    return (
        isinstance(inner_source, Scope)
        and not _join_is_used(scope, join, alias)
        and (
            (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join))
            or (not join.args.get("on") and _has_single_output_row(inner_source))
        )
    )


def _join_is_used(scope, join, alias):
    """Return True if any column outside the join's own ON clause references `alias`."""
    # We need to find all columns that reference this join.
    # But columns in the ON clause shouldn't count.
    on = join.args.get("on")
    if on:
        on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
    else:
        on_clause_columns = set()

    # Test the predicate itself rather than relying on the truthiness of column nodes.
    return any(id(column) not in on_clause_columns for column in scope.source_columns(alias))


def _is_joined_on_all_unique_outputs(scope, join):
    """Return True if the join keys cover every unique output column of `scope`."""
    unique_outputs = _unique_outputs(scope)
    if not unique_outputs:
        return False

    _, join_keys, _ = join_condition(join)
    remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
    return not remaining_unique_outputs


def _unique_outputs(scope):
    """Determine output columns of `scope` that must have a unique combination per row"""
    if scope.expression.args.get("distinct"):
        return set(scope.expression.named_selects)

    group = scope.expression.args.get("group")
    if group:
        grouped_expressions = set(group.expressions)
        grouped_outputs = set()

        unique_outputs = set()
        for select in scope.expression.selects:
            output = select.unalias()
            if output in grouped_expressions:
                grouped_outputs.add(output)
                unique_outputs.add(select.alias_or_name)

        # All the grouped expressions must be in the output
        if not grouped_expressions.difference(grouped_outputs):
            return unique_outputs
        else:
            return set()

    if _has_single_output_row(scope):
        return set(scope.expression.named_selects)

    return set()


def _has_single_output_row(scope):
    """
    Return True if the SELECT in `scope` is treated as producing a single row.

    That is the case when it projects only aggregates and has no GROUP BY,
    when it has LIMIT 1, or when it has no FROM clause.
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return False

    # An all-aggregate projection collapses to one row only without GROUP BY;
    # with GROUP BY there is one row per group, so the join could multiply rows.
    only_aggregates = not select.args.get("group") and all(
        isinstance(e.unalias(), exp.AggFunc) for e in select.selects
    )

    # NOTE(review): LIMIT 1 and FROM-less selects can still yield zero rows
    # (e.g. with a filter); that case is not handled here -- confirm it is acceptable.
    return bool(only_aggregates or _is_limit_1(scope) or not select.args.get("from_"))


def _is_limit_1(scope):
    """Return True if the scope's expression has a LIMIT whose value is the literal 1."""
    limit = scope.expression.args.get("limit")
    return bool(limit and limit.expression.this == "1")


def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equalities that compare the joined source against other sources are pulled
    out of a copy of the ON clause and replaced there with TRUE.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source keys, join keys, remaining predicate)
    """
    name = join.alias_or_name
    # Work on a copy: extracted equalities are replaced in it below.
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def extract_condition(condition):
        """Record the operands of an equality that has the joined source on exactly one side."""
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        # Wrap a lone predicate in an AND so it has a parent to be replaced in.
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)

        for condition in on.flatten():
            if isinstance(condition, exp.EQ):
                extract_condition(condition)
    elif normalized(on, dnf=True):
        conditions = None

        # Keep only the equalities that occur in every branch of the disjunction,
        # retaining each occurrence so all of them get replaced.
        for condition in on.flatten():
            parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
            if conditions is None:
                conditions = parts
            else:
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions:
            extract_condition(condition)

    return source_key, join_key, on
def
eliminate_joins(expression: ~E) -> ~E:
def eliminate_joins(expression: E) -> E:
    """
    Prune joins that contribute nothing to the query.

    A join is only pruned when it is known that its condition cannot produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression: expression to optimize

    Returns:
        The optimized expression
    """
    for scope in traverse_scope(expression):
        scope_joins = scope.expression.args.get("joins", [])

        # Skip scopes without joins. Also skip scopes with unqualified columns:
        # without a qualifier it is hard to tell which source a column belongs to,
        # so we cannot prove that a join is unused.
        if not scope_joins or scope.unqualified_columns:
            continue

        # Walk the joins back to front so chains of unused joins can be removed.
        for candidate in reversed(scope_joins):
            if candidate.is_semi_or_anti_join:
                continue

            source_name = candidate.alias_or_name
            if not _should_eliminate_join(scope, candidate, source_name):
                continue

            candidate.pop()
            scope.remove_source(source_name)

    return expression
Remove unused joins from an expression.
This only removes joins when we know that the join condition doesn't produce duplicate rows.
Example:
>>> import sqlglot
>>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
>>> expression = sqlglot.parse_one(sql)
>>> eliminate_joins(expression).sql()
'SELECT x.a FROM x'
Arguments:
- expression: expression to optimize
Returns:
The optimized expression
def
join_condition(join):
def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equalities that compare the joined source against other sources are pulled
    out of a copy of the ON clause and replaced there with TRUE.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source keys, join keys, remaining predicate)
    """
    name = join.alias_or_name
    # Work on a copy: extracted equalities are replaced in it below.
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def extract_condition(condition):
        """Record the operands of an equality that has the joined source on exactly one side."""
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        # The side that references the joined source is the join key, the other
        # side is the source key. Equalities with the source on both sides, or
        # on neither, are left in place.
        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    # NOTE(review): normalized(on) presumably checks for conjunctive normal form
    # and normalized(on, dnf=True) for disjunctive normal form -- confirm.
    if normalized(on):
        # Presumably wraps a lone predicate in an AND so it has a parent to be replaced in.
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)

        for condition in on.flatten():
            if isinstance(condition, exp.EQ):
                extract_condition(condition)
    elif normalized(on, dnf=True):
        conditions = None

        # Keep only the equalities that occur in every branch seen so far,
        # retaining each occurrence so all of them get replaced.
        for condition in on.flatten():
            parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
            if conditions is None:
                conditions = parts
            else:
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions:
            extract_condition(condition)

    return source_key, join_key, on
Extract the join condition from a join expression.
Arguments:
- join (exp.Join)
Returns:
tuple[list[exp.Expression], list[exp.Expression], exp.Expression]: Tuple of (source keys, join keys, remaining predicate). The keys are expressions taken from the ON clause, not strings.