sqlglot.optimizer.pushdown_predicates
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify
from sqlglot import Dialect


def pushdown_predicates(expression, dialect=None):
    """
    Rewrite sqlglot AST to push predicates down into FROMs and JOINs.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: dialect name or instance, resolved with `Dialect.get_or_raise`; it is
            forwarded to `simplify` and used to detect Presto-derived dialects.
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    from sqlglot.dialects.presto import Presto

    root = build_scope(expression)

    dialect = Dialect.get_or_raise(dialect)
    unnest_requires_cross_join = isinstance(dialect, Presto)

    if not root:
        return expression

    scope_ref_count = root.ref_count()

    # Walk the scopes in the reverse of the order produced by traverse().
    for scope in reversed(list(root.traverse())):
        select = scope.expression
        where = select.args.get("where")

        if where:
            where_sources = scope.selected_sources
            join_index = {}
            for position, join in enumerate(select.args.get("joins") or []):
                join_index[join.alias_or_name] = position

            # a right join can only push down to itself and not the source FROM table
            # presto, trino and athena don't support inner joins where the RHS is an
            # UNNEST expression
            pushdown_allowed = True
            for alias, (node, source) in where_sources.items():
                parent = node.find_ancestor(exp.Join, exp.From)
                if not isinstance(parent, exp.Join):
                    continue
                if parent.side == "RIGHT":
                    where_sources = {alias: (node, source)}
                    break
                if isinstance(node, exp.Unnest) and unnest_requires_cross_join:
                    pushdown_allowed = False
                    break

            if pushdown_allowed:
                pushdown(where.this, where_sources, scope_ref_count, dialect, join_index)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself
        for join in select.args.get("joins") or []:
            name = join.alias_or_name
            if name not in scope.selected_sources:
                continue
            own_source = {name: scope.selected_sources[name]}
            pushdown(join.args.get("on"), own_source, scope_ref_count, dialect)

    return expression


def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify a condition, split it into predicates and push each one down.

    The condition is replaced in its tree by its simplified form. It is handled as
    CNF-like when it is in CNF or when it is not in DNF; only a condition that is in
    DNF but not in CNF takes the DNF path.

    Args:
        condition: the WHERE or JOIN ON condition; nothing happens when it is falsy.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by `id(source)`.
        dialect: dialect forwarded to `simplify`.
        join_index: optional mapping of join name to its position in the select's
            joins; only used on the CNF path.
    """
    if not condition:
        return

    simplified = simplify(condition, dialect=dialect)
    condition = condition.replace(simplified)
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    # CNF-like conditions are split on AND, DNF conditions on OR.
    connector = exp.And if cnf_like else exp.Or
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    if not cnf_like:
        pushdown_dnf(predicates, sources, scope_ref_count)
        return

    pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)


def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    If the predicates are in CNF-like form, we can simply replace each block in the parent.

    A predicate that gets pushed down is replaced by TRUE in its original position.

    Args:
        predicates: the AND-ed blocks of the condition.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by `id(source)`.
        join_index: optional mapping of join name to its position among the joins.
    """
    positions = join_index or {}

    for predicate in predicates:
        targets = nodes_for_predicate(predicate, sources, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Join):
                join_name = target.alias_or_name
                other_tables = exp.column_table_names(predicate, join_name)

                # Don't push the predicate if it references tables that appear in later joins
                own_position = positions[join_name]
                references_later_join = any(
                    positions.get(table, -1) >= own_position for table in other_tables
                )
                if references_later_join:
                    continue

                predicate.replace(exp.true())
                target.on(predicate, copy=False)
                break
            elif isinstance(target, exp.Select):
                predicate.replace(exp.true())
                inner_predicate = replace_aliases(target, predicate)
                # Predicates containing an aggregate go to HAVING instead of WHERE.
                if find_in_scope(inner_predicate, exp.AggFunc):
                    target.having(inner_predicate, copy=False)
                else:
                    target.where(inner_predicate, copy=False)


def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    The OR of all blocks is pushed to a table only when every block resolves to a
    node for that table. If any block were left out, the pushed-down condition would
    be stricter than the original one and would drop rows that the original accepts.

    Args:
        predicates: the OR-ed blocks of the condition.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by `id(source)`.
    """
    predicates = list(predicates)
    if not predicates:
        return

    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    pushdown_tables = set(exp.column_table_names(predicates[0]))
    for predicate in predicates[1:]:
        pushdown_tables.intersection_update(exp.column_table_names(predicate))

    if not pushdown_tables:
        return

    # Resolve the candidate nodes of every block once, before anything is mutated.
    nodes_by_block = [
        nodes_for_predicate(predicate, sources, scope_ref_count) for predicate in predicates
    ]

    for table in sorted(pushdown_tables):
        # Skip the table unless every block can be pushed to it; a partial OR is unsound.
        if not all(table in nodes for nodes in nodes_by_block):
            continue

        node = nodes_by_block[-1][table]

        # Build a separate OR chain per table, since it is attached without copying.
        condition = predicates[0]
        for predicate in predicates[1:]:
            condition = exp.or_(condition, predicate)

        if isinstance(node, exp.Join):
            node.on(condition, copy=False)
        elif isinstance(node, exp.Select):
            inner_predicate = replace_aliases(node, condition)
            # Predicates containing an aggregate go to HAVING instead of WHERE.
            if find_in_scope(inner_predicate, exp.AggFunc):
                node.having(inner_predicate, copy=False)
            else:
                node.where(inner_predicate, copy=False)


def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find, per referenced table, the node a predicate could be pushed into.

    Args:
        predicate: the predicate whose column tables are inspected.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by `id(source)`.
    Returns:
        dict: table name -> `exp.Join` or `exp.Select` node. An empty dict is returned
        as soon as a recursive WITH or a join whose side is set and is not RIGHT is hit.
    """
    resolved = {}
    referenced = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
    single_table = len(referenced) == 1

    for table in sorted(referenced):
        entry = sources.get(table)
        node, source = entry if entry else (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            side = node.side
            if side and side != "RIGHT":
                return {}
            resolved[table] = node
            continue

        if not (isinstance(node, exp.Select) and single_table):
            continue

        # We can't push down window expressions
        has_window_expression = any(
            projection for projection in node.selects if projection.find(exp.Window)
        )
        # we can't push down predicates to select statements if they are grouped or
        # referenced in multiple places.
        blocked = (
            node.args.get("group")
            or scope_ref_count[id(source)] >= 2
            or has_window_expression
        )
        if not blocked:
            resolved[table] = node

    return resolved


def replace_aliases(source, predicate):
    """
    Rewrite a predicate so that it refers to the projections of `source` directly.

    Every column whose name matches a projection of `source` is replaced by a copy of
    that projection: the aliased expression for an `exp.Alias`, the projection itself
    otherwise.

    Args:
        source: expression exposing `selects`.
        predicate: the predicate to rewrite.
    Returns:
        The result of `predicate.transform` with the substitutions applied.
    """
    projections = {}
    for projection in source.selects:
        is_alias = isinstance(projection, exp.Alias)
        key = projection.alias if is_alias else projection.name
        projections[key] = projection.this if is_alias else projection

    def _substitute(node):
        if not isinstance(node, exp.Column):
            return node
        name = node.name
        if name not in projections:
            return node
        return projections[name].copy()

    return predicate.transform(_substitute)
def
pushdown_predicates(expression, dialect=None):
def pushdown_predicates(expression, dialect=None):
    """
    Rewrite sqlglot AST to push predicates down into FROMs and JOINs.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: dialect name or instance, resolved with `Dialect.get_or_raise`; it is
            forwarded to `simplify` and used to detect Presto-derived dialects.
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    from sqlglot.dialects.presto import Presto

    root = build_scope(expression)

    dialect = Dialect.get_or_raise(dialect)
    unnest_requires_cross_join = isinstance(dialect, Presto)

    if not root:
        return expression

    scope_ref_count = root.ref_count()

    # Walk the scopes in the reverse of the order produced by traverse().
    for scope in reversed(list(root.traverse())):
        select = scope.expression
        where = select.args.get("where")

        if where:
            where_sources = scope.selected_sources
            join_index = {}
            for position, join in enumerate(select.args.get("joins") or []):
                join_index[join.alias_or_name] = position

            # a right join can only push down to itself and not the source FROM table
            # presto, trino and athena don't support inner joins where the RHS is an
            # UNNEST expression
            pushdown_allowed = True
            for alias, (node, source) in where_sources.items():
                parent = node.find_ancestor(exp.Join, exp.From)
                if not isinstance(parent, exp.Join):
                    continue
                if parent.side == "RIGHT":
                    where_sources = {alias: (node, source)}
                    break
                if isinstance(node, exp.Unnest) and unnest_requires_cross_join:
                    pushdown_allowed = False
                    break

            if pushdown_allowed:
                pushdown(where.this, where_sources, scope_ref_count, dialect, join_index)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself
        for join in select.args.get("joins") or []:
            name = join.alias_or_name
            if name not in scope.selected_sources:
                continue
            own_source = {name: scope.selected_sources[name]}
            pushdown(join.args.get("on"), own_source, scope_ref_count, dialect)

    return expression
Rewrite the sqlglot AST to push predicates down into FROM and JOIN clauses.
Example:
>>> import sqlglot >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1" >>> expression = sqlglot.parse_one(sql) >>> pushdown_predicates(expression).sql() 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
- expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
def
pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify a condition, split it into predicates and push each one down.

    The condition is replaced in its tree by its simplified form. It is handled as
    CNF-like when it is in CNF or when it is not in DNF; only a condition that is in
    DNF but not in CNF takes the DNF path.

    Args:
        condition: the WHERE or JOIN ON condition; nothing happens when it is falsy.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by `id(source)`.
        dialect: dialect forwarded to `simplify`.
        join_index: optional mapping of join name to its position in the select's
            joins; only used on the CNF path.
    """
    if not condition:
        return

    simplified = simplify(condition, dialect=dialect)
    condition = condition.replace(simplified)
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    # CNF-like conditions are split on AND, DNF conditions on OR.
    connector = exp.And if cnf_like else exp.Or
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    if not cnf_like:
        pushdown_dnf(predicates, sources, scope_ref_count)
        return

    pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
def
pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    If the predicates are in CNF-like form, we can simply replace each block in the parent.

    A predicate that gets pushed down is replaced by TRUE in its original position.

    Args:
        predicates: the AND-ed blocks of the condition.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by `id(source)`.
        join_index: optional mapping of join name to its position among the joins.
    """
    positions = join_index or {}

    for predicate in predicates:
        targets = nodes_for_predicate(predicate, sources, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Join):
                join_name = target.alias_or_name
                other_tables = exp.column_table_names(predicate, join_name)

                # Don't push the predicate if it references tables that appear in later joins
                own_position = positions[join_name]
                references_later_join = any(
                    positions.get(table, -1) >= own_position for table in other_tables
                )
                if references_later_join:
                    continue

                predicate.replace(exp.true())
                target.on(predicate, copy=False)
                break
            elif isinstance(target, exp.Select):
                predicate.replace(exp.true())
                inner_predicate = replace_aliases(target, predicate)
                # Predicates containing an aggregate go to HAVING instead of WHERE.
                if find_in_scope(inner_predicate, exp.AggFunc):
                    target.having(inner_predicate, copy=False)
                else:
                    target.where(inner_predicate, copy=False)
If the predicates are in CNF-like form, we can simply replace each block in the parent.
def
pushdown_dnf(predicates, sources, scope_ref_count):
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    The OR of all blocks is pushed to a table only when every block resolves to a
    node for that table. If any block were left out, the pushed-down condition would
    be stricter than the original one and would drop rows that the original accepts.

    Args:
        predicates: the OR-ed blocks of the condition.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by `id(source)`.
    """
    predicates = list(predicates)
    if not predicates:
        return

    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    pushdown_tables = set(exp.column_table_names(predicates[0]))
    for predicate in predicates[1:]:
        pushdown_tables.intersection_update(exp.column_table_names(predicate))

    if not pushdown_tables:
        return

    # Resolve the candidate nodes of every block once, before anything is mutated.
    nodes_by_block = [
        nodes_for_predicate(predicate, sources, scope_ref_count) for predicate in predicates
    ]

    for table in sorted(pushdown_tables):
        # Skip the table unless every block can be pushed to it; a partial OR is unsound.
        if not all(table in nodes for nodes in nodes_by_block):
            continue

        node = nodes_by_block[-1][table]

        # Build a separate OR chain per table, since it is attached without copying.
        condition = predicates[0]
        for predicate in predicates[1:]:
            condition = exp.or_(condition, predicate)

        if isinstance(node, exp.Join):
            node.on(condition, copy=False)
        elif isinstance(node, exp.Select):
            inner_predicate = replace_aliases(node, condition)
            # Predicates containing an aggregate go to HAVING instead of WHERE.
            if find_in_scope(inner_predicate, exp.AggFunc):
                node.having(inner_predicate, copy=False)
            else:
                node.where(inner_predicate, copy=False)
If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form.
def
nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find, per referenced table, the node a predicate could be pushed into.

    Args:
        predicate: the predicate whose column tables are inspected.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by `id(source)`.
    Returns:
        dict: table name -> `exp.Join` or `exp.Select` node. An empty dict is returned
        as soon as a recursive WITH or a join whose side is set and is not RIGHT is hit.
    """
    resolved = {}
    referenced = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
    single_table = len(referenced) == 1

    for table in sorted(referenced):
        entry = sources.get(table)
        node, source = entry if entry else (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            side = node.side
            if side and side != "RIGHT":
                return {}
            resolved[table] = node
            continue

        if not (isinstance(node, exp.Select) and single_table):
            continue

        # We can't push down window expressions
        has_window_expression = any(
            projection for projection in node.selects if projection.find(exp.Window)
        )
        # we can't push down predicates to select statements if they are grouped or
        # referenced in multiple places.
        blocked = (
            node.args.get("group")
            or scope_ref_count[id(source)] >= 2
            or has_window_expression
        )
        if not blocked:
            resolved[table] = node

    return resolved
def
replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Rewrite a predicate so that it refers to the projections of `source` directly.

    Every column whose name matches a projection of `source` is replaced by a copy of
    that projection: the aliased expression for an `exp.Alias`, the projection itself
    otherwise.

    Args:
        source: expression exposing `selects`.
        predicate: the predicate to rewrite.
    Returns:
        The result of `predicate.transform` with the substitutions applied.
    """
    projections = {}
    for projection in source.selects:
        is_alias = isinstance(projection, exp.Alias)
        key = projection.alias if is_alias else projection.name
        projections[key] = projection.this if is_alias else projection

    def _substitute(node):
        if not isinstance(node, exp.Column):
            return node
        name = node.name
        if name not in projections:
            return node
        return projections[name].copy()

    return predicate.transform(_substitute)