sqlglot.optimizer.eliminate_joins
from sqlglot import expressions as exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import Scope, traverse_scope


def eliminate_joins(expression):
    """
    Remove unused joins from an expression.

    This only removes joins when we know that the join condition doesn't produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression (the same object, modified in place)
    """
    for scope in traverse_scope(expression):
        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
        # It's probably possible to infer this from the outputs of derived tables.
        # But for now, let's just skip this rule.
        if scope.unqualified_columns:
            continue

        joins = scope.expression.args.get("joins", [])

        # Reverse the joins so we can remove chains of unused joins
        for join in reversed(joins):
            if join.is_semi_or_anti_join:
                continue

            alias = join.alias_or_name
            if _should_eliminate_join(scope, join, alias):
                join.pop()
                scope.remove_source(alias)
    return expression


def _should_eliminate_join(scope, join, alias):
    """
    Decide whether `join` can be dropped from `scope` without changing the result.

    The join is only eliminated when all of the following hold:
      * the source registered under `alias` is itself a `Scope`,
      * no column outside the join's ON clause references `alias`,
      * and either it is a LEFT join whose keys cover every unique output of the
        joined source, or it has no ON clause and the joined source yields a single row.
    """
    inner_source = scope.sources.get(alias)
    return (
        isinstance(inner_source, Scope)
        and not _join_is_used(scope, join, alias)
        and (
            (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join))
            or (not join.args.get("on") and _has_single_output_row(inner_source))
        )
    )


def _join_is_used(scope, join, alias):
    """Return True if any column outside the join's own ON clause references `alias`."""
    # We need to find all columns that reference this join.
    # But columns in the ON clause shouldn't count.
    on = join.args.get("on")
    if on:
        # Compare by identity: an equal-looking column elsewhere in the query still counts as a use.
        on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
    else:
        on_clause_columns = set()
    # Test the predicate itself rather than relying on the truthiness of the column objects.
    return any(id(column) not in on_clause_columns for column in scope.source_columns(alias))


def _is_joined_on_all_unique_outputs(scope, join):
    """Return True if the join keys of `join` cover every unique output column of `scope`."""
    unique_outputs = _unique_outputs(scope)
    if not unique_outputs:
        return False

    _, join_keys, _ = join_condition(join)
    remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
    return not remaining_unique_outputs


def _unique_outputs(scope):
    """Determine output columns of `scope` that must have a unique combination per row"""
    if scope.expression.args.get("distinct"):
        return set(scope.expression.named_selects)

    group = scope.expression.args.get("group")
    if group:
        grouped_expressions = set(group.expressions)
        grouped_outputs = set()

        unique_outputs = set()
        for select in scope.expression.selects:
            output = select.unalias()
            if output in grouped_expressions:
                grouped_outputs.add(output)
                unique_outputs.add(select.alias_or_name)

        # All the grouped expressions must be in the output
        if not grouped_expressions.difference(grouped_outputs):
            return unique_outputs
        else:
            return set()

    if _has_single_output_row(scope):
        return set(scope.expression.named_selects)

    return set()


def _has_single_output_row(scope):
    """
    Return True if the SELECT of `scope` is known to produce a single output row.

    That is the case when every projection is an aggregate and there is no GROUP BY,
    when the query has `LIMIT 1`, or when it has no FROM clause.
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return False

    # Aggregates collapse the input to one row only without GROUP BY;
    # with GROUP BY there is one output row per group.
    only_aggregates = not select.args.get("group") and all(
        isinstance(e.unalias(), exp.AggFunc) for e in select.selects
    )
    return bool(only_aggregates or _is_limit_1(scope) or not select.args.get("from"))


def _is_limit_1(scope):
    """Return True if the expression of `scope` carries a `LIMIT 1`."""
    limit = scope.expression.args.get("limit")
    return limit is not None and limit.expression.this == "1"


def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equality predicates that compare the joined source against other sources are
    pulled out of a copy of the ON clause and replaced there with TRUE.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source key, join key, remaining predicate)
    """
    name = join.alias_or_name
    # Work on a copy so the join's own ON clause is left untouched.
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def extract_condition(condition):
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        # Only take equalities where exactly one side references the joined source.
        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)

        for condition in on.flatten():
            if isinstance(condition, exp.EQ):
                extract_condition(condition)
    elif normalized(on, dnf=True):
        conditions = None

        # Keep only the equalities that appear in every branch of the disjunction.
        for condition in on.flatten():
            parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
            if conditions is None:
                conditions = parts
            else:
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions:
            extract_condition(condition)

    return source_key, join_key, on
def
eliminate_joins(expression):
def eliminate_joins(expression):
    """
    Drop joins whose output is never used.

    A join is only dropped when it is known that its join condition cannot
    produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    for scope in traverse_scope(expression):
        # With unqualified columns we cannot tell which source a column belongs to,
        # so we cannot prove a join is unused. Skip such scopes entirely.
        if scope.unqualified_columns:
            continue

        scope_joins = scope.expression.args.get("joins", [])

        # Walk from the last join to the first so that chains of unused joins unravel.
        for candidate in reversed(scope_joins):
            if not candidate.is_semi_or_anti_join:
                source_name = candidate.alias_or_name
                if _should_eliminate_join(scope, candidate, source_name):
                    candidate.pop()
                    scope.remove_source(source_name)

    return expression
Remove unused joins from an expression.
This only removes joins when we know that the join condition doesn't produce duplicate rows.
Example:
>>> import sqlglot
>>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
>>> expression = sqlglot.parse_one(sql)
>>> eliminate_joins(expression).sql()
'SELECT x.a FROM x'
Arguments:
- expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
def
join_condition(join):
def join_condition(join):
    """
    Split a join's ON clause into join keys and the leftover predicate.

    Equalities that compare the joined source with other sources are removed
    from a copy of the ON clause (replaced with TRUE) and returned as keys.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source key, join key, remaining predicate)
    """
    name = join.alias_or_name
    source_key = []
    join_key = []
    # Operate on a copy so the original ON clause is not modified.
    on = (join.args.get("on") or exp.true()).copy()

    def pull_keys(equality):
        lhs, rhs = equality.unnest_operands()
        in_lhs = name in exp.column_table_names(lhs)
        in_rhs = name in exp.column_table_names(rhs)

        # Usable only when exactly one operand references the joined source.
        if in_lhs == in_rhs:
            return

        joined_side, other_side = (lhs, rhs) if in_lhs else (rhs, lhs)
        join_key.append(joined_side)
        source_key.append(other_side)
        equality.replace(exp.true())

    # Example:
    #   SELECT ... FROM x JOIN y ON x.a = y.b AND y.b > 1
    # yields y.b as the join key and x.a as the source key.
    if normalized(on):
        if not isinstance(on, exp.And):
            on = exp.and_(on, exp.true(), copy=False)

        for term in on.flatten():
            if isinstance(term, exp.EQ):
                pull_keys(term)
    elif normalized(on, dnf=True):
        shared = None

        # Retain only the equalities present in every branch of the disjunction.
        for branch in on.flatten():
            equalities = [node for node in branch.flatten() if isinstance(node, exp.EQ)]
            if shared is None:
                shared = equalities
                continue

            matched = []
            for equality in equalities:
                twins = [seen for seen in shared if equality == seen]
                if twins:
                    matched.append(equality)
                    matched.extend(twins)
            shared = matched

        for equality in shared:
            pull_keys(equality)

    return source_key, join_key, on
Extract the join condition from a join expression.
Arguments:
- join (exp.Join)
Returns:
tuple[list[exp.Expression], list[exp.Expression], exp.Expression]: Tuple of (source key, join key, remaining predicate)