sqlglot.optimizer.eliminate_joins
from sqlglot import expressions as exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import Scope, traverse_scope


def eliminate_joins(expression):
    """
    Remove unused joins from an expression.

    This only removes joins when we know that the join condition doesn't produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression (sqlglot.Expression): expression to optimize (modified in place)
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    for scope in traverse_scope(expression):
        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
        # It's probably possible to infer this from the outputs of derived tables.
        # But for now, let's just skip this rule.
        if scope.unqualified_columns:
            continue

        joins = scope.expression.args.get("joins") or []

        # Reverse the joins so we can remove chains of unused joins
        for join in reversed(joins):
            alias = join.alias_or_name
            if _should_eliminate_join(scope, join, alias):
                join.pop()
                scope.remove_source(alias)
    return expression


def _should_eliminate_join(scope, join, alias):
    """Return True if `join` (whose source is named `alias`) can be dropped from `scope`."""
    inner_source = scope.sources.get(alias)
    return (
        isinstance(inner_source, Scope)
        and not _join_is_used(scope, join, alias)
        and (
            (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join))
            or (
                # Only a true cross join (neither ON nor USING) qualifies here: a USING
                # clause filters rows just like ON does, so it must not be dropped.
                not join.args.get("on")
                and not join.args.get("using")
                and _has_single_output_row(inner_source)
            )
        )
    )


def _join_is_used(scope, join, alias):
    """Return True if any column of `scope` outside the join's ON clause references `alias`."""
    # We need to find all columns that reference this join.
    # But columns in the ON clause shouldn't count.
    on = join.args.get("on")
    if on:
        on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
    else:
        on_clause_columns = set()
    return any(
        column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
    )


def _is_joined_on_all_unique_outputs(scope, join):
    """
    Return True if the join's equality keys cover every output of `scope` that is unique per row.

    When that holds, each outer row can match at most one row of `scope`, so the join cannot
    duplicate outer rows.
    """
    unique_outputs = _unique_outputs(scope)
    if not unique_outputs:
        return False

    _, join_keys, _ = join_condition(join)
    remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
    return not remaining_unique_outputs


def _unique_outputs(scope):
    """Determine output columns of `scope` that must have a unique combination per row"""
    if scope.expression.args.get("distinct"):
        return set(scope.expression.named_selects)

    group = scope.expression.args.get("group")
    if group:
        grouped_expressions = set(group.expressions)
        grouped_outputs = set()

        unique_outputs = set()
        for select in scope.expression.selects:
            output = select.unalias()
            if output in grouped_expressions:
                grouped_outputs.add(output)
                unique_outputs.add(select.alias_or_name)

        # All the grouped expressions must be in the output
        if not grouped_expressions.difference(grouped_outputs):
            return unique_outputs
        else:
            return set()

    if _has_single_output_row(scope):
        return set(scope.expression.named_selects)

    return set()


def _has_single_output_row(scope):
    """
    Return True if `scope` is a SELECT that yields at most one row.

    That is the case for an aggregate-only projection without GROUP BY, a literal LIMIT 1,
    or a SELECT without a FROM clause.
    """
    select = scope.expression
    return isinstance(select, exp.Select) and (
        (
            # With GROUP BY, an aggregate-only projection yields one row per group.
            not select.args.get("group")
            and all(isinstance(e.unalias(), exp.AggFunc) for e in select.selects)
        )
        or _is_limit_1(scope)
        or not select.args.get("from")
    )


def _is_limit_1(scope):
    """Return a truthy value if `scope` has a literal ``LIMIT 1``."""
    limit = scope.expression.args.get("limit")
    return limit and limit.expression.this == "1"


def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equalities between a column of the joined source and an expression over the other
    sources are pulled out as key pairs and replaced by TRUE in a copy of the ON clause.
    Whatever is left of that copy is returned as the remaining predicate.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source key, join key, remaining predicate)
    """
    name = join.alias_or_name
    # Work on a copy so the caller's join is never mutated; a missing ON behaves like TRUE.
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def extract_condition(condition):
        # Only an equality where exactly one side references the joined source is a key.
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        # Look through redundant parentheses so `ON (a = b AND ...)` still yields its keys.
        root = on.unnest()
        if not isinstance(root, exp.And):
            on = root = exp.and_(on, exp.true(), copy=False)

        for condition in root.flatten():
            if isinstance(condition, exp.EQ):
                extract_condition(condition)
    elif normalized(on, dnf=True):
        # An equality is only a required key if it is a top-level conjunct of *every*
        # disjunct. Never dig into other node types (e.g. NOT (a = b)), which would
        # wrongly expose a negated equality as a key.
        root = on.unnest()
        disjuncts = root.flatten() if isinstance(root, exp.Or) else [root]
        conditions = None

        for disjunct in disjuncts:
            conjuncts = disjunct.flatten() if isinstance(disjunct, exp.And) else [disjunct]
            parts = [part for part in conjuncts if isinstance(part, exp.EQ)]
            if conditions is None:
                conditions = parts
            else:
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions or []:
            extract_condition(condition)

    return source_key, join_key, on
def
eliminate_joins(expression):
def eliminate_joins(expression):
    """
    Remove unused joins from an expression.

    A join is only removed when it is provably safe, i.e. when we know that the join
    condition doesn't produce duplicate rows.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
        >>> expression = sqlglot.parse_one(sql)
        >>> eliminate_joins(expression).sql()
        'SELECT x.a FROM x'

    Args:
        expression (sqlglot.Expression): expression to optimize (modified in place)
    Returns:
        sqlglot.Expression: the same expression object, with unused joins removed
    """
    for scope in traverse_scope(expression):
        # Without full qualification we cannot reliably tell which source a column
        # belongs to, so scopes with unqualified columns are skipped entirely.
        if scope.unqualified_columns:
            continue

        # Visit joins from last to first so that chains of unused joins can be removed.
        for join in scope.expression.args.get("joins", [])[::-1]:
            source_name = join.alias_or_name
            if not _should_eliminate_join(scope, join, source_name):
                continue
            join.pop()
            scope.remove_source(source_name)

    return expression
Remove unused joins from an expression.
This only removes joins when we know that the join condition doesn't produce duplicate rows.
Example:
>>> import sqlglot
>>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
>>> expression = sqlglot.parse_one(sql)
>>> eliminate_joins(expression).sql()
'SELECT x.a FROM x'
Arguments:
- expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
def
join_condition(join):
def join_condition(join):
    """
    Extract the join condition from a join expression.

    Equalities between a column of the joined source and an expression over the other
    sources are pulled out as key pairs and replaced by TRUE in a copy of the ON clause.
    Whatever is left of that copy is returned as the remaining predicate.

    Args:
        join (exp.Join)
    Returns:
        tuple[list[exp.Expression], list[exp.Expression], exp.Expression]:
            Tuple of (source key, join key, remaining predicate)
    """
    name = join.alias_or_name
    # Work on a copy so the caller's join is never mutated; a missing ON behaves like TRUE.
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []

    def extract_condition(condition):
        # Only an equality where exactly one side references the joined source is a key.
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x
    # JOIN y
    #   ON x.a = y.b AND y.b > 1
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        # Look through redundant parentheses so `ON (a = b AND ...)` still yields its keys.
        root = on.unnest()
        if not isinstance(root, exp.And):
            on = root = exp.and_(on, exp.true(), copy=False)

        for condition in root.flatten():
            if isinstance(condition, exp.EQ):
                extract_condition(condition)
    elif normalized(on, dnf=True):
        # An equality is only a required key if it is a top-level conjunct of *every*
        # disjunct. Never dig into other node types (e.g. NOT (a = b)), which would
        # wrongly expose a negated equality as a key.
        root = on.unnest()
        disjuncts = root.flatten() if isinstance(root, exp.Or) else [root]
        conditions = None

        for disjunct in disjuncts:
            conjuncts = disjunct.flatten() if isinstance(disjunct, exp.And) else [disjunct]
            parts = [part for part in conjuncts if isinstance(part, exp.EQ)]
            if conditions is None:
                conditions = parts
            else:
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions or []:
            extract_condition(condition)

    return source_key, join_key, on
Extract the join condition from a join expression.
Arguments:
- join (exp.Join)
Returns:
tuple[list[exp.Expression], list[exp.Expression], exp.Expression]: Tuple of (source key, join key, remaining predicate)