sqlglot.optimizer.pushdown_predicates
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify
from sqlglot import Dialect

if t.TYPE_CHECKING:
    from sqlglot._typing import E
    from sqlglot.dialects.dialect import DialectType


def pushdown_predicates(expression: E, dialect: DialectType = None) -> E:
    """
    Rewrite sqlglot AST to push down predicates into FROMs and JOINs.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: dialect resolved through `Dialect.get_or_raise`; it is forwarded to
            `simplify` and decides whether joined UNNESTs block the WHERE pushdown
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    # imported locally rather than at module level
    # NOTE(review): presumably to avoid a circular import - confirm
    from sqlglot.dialects.athena import Athena
    from sqlglot.dialects.presto import Presto

    root = build_scope(expression)

    dialect = Dialect.get_or_raise(dialect)
    unnest_requires_cross_join = isinstance(dialect, (Athena, Presto))

    if root:
        scope_ref_count = root.ref_count()

        # scopes are visited in the reverse of the order produced by traverse()
        for scope in reversed(list(root.traverse())):
            select = scope.expression
            where = select.args.get("where")
            if where:
                selected_sources = scope.selected_sources
                # position of every join in the select, keyed by the join's alias or name
                join_index = {
                    join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
                }

                # a right join can only push down to itself and not the source FROM table
                # presto, trino and athena don't support inner joins where the RHS is an UNNEST expression
                pushdown_allowed = True
                for k, (node, source) in selected_sources.items():
                    parent = node.find_ancestor(exp.Join, exp.From)
                    if isinstance(parent, exp.Join):
                        if parent.side == "RIGHT":
                            # restrict the candidates to the right-joined source only
                            selected_sources = {k: (node, source)}
                            break
                        if isinstance(node, exp.Unnest) and unnest_requires_cross_join:
                            pushdown_allowed = False
                            break

                if pushdown_allowed:
                    pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)

            # joins should only push down into themselves, not into other joins,
            # so we limit the selected sources to the join's own source
            for join in select.args.get("joins") or []:
                name = join.alias_or_name
                if name in scope.selected_sources:
                    pushdown(
                        join.args.get("on"),
                        {name: scope.selected_sources[name]},
                        scope_ref_count,
                        dialect,
                    )

    return expression


def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify `condition` in place and push its predicates down into `sources`.

    Does nothing when `condition` is falsy. Otherwise the condition is replaced in its
    parent by its simplified form and handed to `pushdown_cnf` or `pushdown_dnf`.

    Args:
        condition: the WHERE or JOIN ON condition, or None.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts as returned by the root scope's `ref_count()`.
        dialect: dialect forwarded to `simplify`.
        join_index: optional mapping of join name to its position; only used for
            CNF-like conditions.
    """
    if not condition:
        return

    condition = condition.replace(simplify(condition, dialect=dialect))
    # anything that is not strictly in DNF is handled by the CNF code path
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    # split on AND for CNF-like conditions and on OR for DNF ones
    predicates = list(
        condition.flatten()
        if isinstance(condition, exp.And if cnf_like else exp.Or)
        else [condition]
    )

    if cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
    else:
        pushdown_dnf(predicates, sources, scope_ref_count)


def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    If the predicates are in CNF-like form, we can simply replace each block in the parent.

    A predicate that is pushed down is replaced by TRUE at its original position.
    """
    join_index = join_index or {}
    for predicate in predicates:
        for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
            if isinstance(node, exp.Join):
                name = node.alias_or_name
                # tables referenced by the predicate, other than the join itself
                predicate_tables = exp.column_table_names(predicate, name)

                # Don't push the predicate if it references tables that appear in later joins
                this_index = join_index[name]
                if all(join_index.get(table, -1) < this_index for table in predicate_tables):
                    predicate.replace(exp.true())
                    node.on(predicate, copy=False)
                    break
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                inner_predicate = replace_aliases(node, predicate)
                # predicates containing an aggregate function go to HAVING, the rest to WHERE
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)


def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.
    """
    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    pushdown_tables = set()

    for a in predicates:
        a_tables = exp.column_table_names(a)

        for b in predicates:
            a_tables &= exp.column_table_names(b)

        pushdown_tables.update(a_tables)

    conditions = {}

    # push down all predicates to their respective nodes
    for table in sorted(pushdown_tables):
        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, sources, scope_ref_count)

            if table not in nodes:
                continue

            # OR together the blocks that can be pushed down to this table
            conditions[table] = (
                exp.or_(conditions[table], predicate) if table in conditions else predicate
            )

        # NOTE(review): `nodes` is the mapping computed for the last predicate above
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                inner_predicate = replace_aliases(node, predicate)
                # predicates containing an aggregate function go to HAVING, the rest to WHERE
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)


def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map the tables referenced by `predicate` to the nodes it may be pushed down into.

    Returns:
        A dict of table name to a Join or Select node. The dict is empty when a
        referenced source belongs to a recursive WITH or sits in a join whose side
        is set to anything other than RIGHT.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    # True when the closest Join/Where ancestor of the predicate is a Where
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with_")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select for select in node.selects if select.find(exp.Window)
            )
            # we can't push down predicates to select statements if they are referenced in
            # multiple places.
            if (
                not node.args.get("group")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
            ):
                nodes[table] = node
    return nodes


def replace_aliases(source, predicate):
    """
    Return `predicate` with columns that name a projection of `source` replaced by a
    copy of the projected expression.
    """
    aliases = {}

    # projection name -> projected expression (the aliased expression for an Alias)
    for select in source.selects:
        if isinstance(select, exp.Alias):
            aliases[select.alias] = select.this
        else:
            aliases[select.name] = select

    def _replace_alias(column):
        if isinstance(column, exp.Column) and column.name in aliases:
            return aliases[column.name].copy()
        return column

    return predicate.transform(_replace_alias)
def
pushdown_predicates( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> ~E:
def pushdown_predicates(expression: E, dialect: DialectType = None) -> E:
    """
    Push the predicates of a sqlglot AST down into its FROM sources and JOINs.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: dialect resolved through `Dialect.get_or_raise`; it is forwarded to
            `simplify` and decides whether joined UNNESTs block the WHERE pushdown
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    from sqlglot.dialects.athena import Athena
    from sqlglot.dialects.presto import Presto

    root = build_scope(expression)

    dialect = Dialect.get_or_raise(dialect)
    unnest_requires_cross_join = isinstance(dialect, (Athena, Presto))

    # Nothing to rewrite when no scope could be built.
    if not root:
        return expression

    scope_ref_count = root.ref_count()

    # Scopes are processed in the reverse of the order produced by traverse().
    for scope in reversed(list(root.traverse())):
        select = scope.expression
        where = select.args.get("where")

        if where:
            where_sources = scope.selected_sources

            # Position of every join in the select, keyed by alias or name.
            join_index = {}
            for position, join in enumerate(select.args.get("joins") or []):
                join_index[join.alias_or_name] = position

            # A right join can only push down to itself and not to the FROM source.
            # Presto, Trino and Athena don't support inner joins whose RHS is an UNNEST,
            # so the WHERE pushdown is skipped entirely in that case.
            pushdown_allowed = True
            for alias, (node, source) in where_sources.items():
                parent = node.find_ancestor(exp.Join, exp.From)
                if not isinstance(parent, exp.Join):
                    continue
                if parent.side == "RIGHT":
                    where_sources = {alias: (node, source)}
                    break
                if isinstance(node, exp.Unnest) and unnest_requires_cross_join:
                    pushdown_allowed = False
                    break

            if pushdown_allowed:
                pushdown(where.this, where_sources, scope_ref_count, dialect, join_index)

        # An ON condition may only be pushed into its own join, never into another
        # one, so each join is offered just its own source.
        for join in select.args.get("joins") or []:
            name = join.alias_or_name
            if name not in scope.selected_sources:
                continue
            pushdown(
                join.args.get("on"),
                {name: scope.selected_sources[name]},
                scope_ref_count,
                dialect,
            )

    return expression
Rewrite sqlglot AST to push down predicates into FROMs and JOINs
Example:
>>> import sqlglot >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1" >>> expression = sqlglot.parse_one(sql) >>> pushdown_predicates(expression).sql() 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
- expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
def
pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify `condition` in place and push its predicates down into `sources`.

    Does nothing when `condition` is falsy. Otherwise the condition is replaced in its
    parent by its simplified form and dispatched to the CNF or DNF strategy.

    Args:
        condition: the WHERE or JOIN ON condition, or None.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts as returned by the root scope's `ref_count()`.
        dialect: dialect forwarded to `simplify`.
        join_index: optional mapping of join name to its position; only used for
            CNF-like conditions.
    """
    if not condition:
        return

    simplified = simplify(condition, dialect=dialect)
    condition = condition.replace(simplified)

    # Anything that is not strictly in DNF goes through the CNF code path.
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    # CNF-like conditions are split on AND, DNF ones on OR.
    connector = exp.And if cnf_like else exp.Or
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    if cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
    else:
        pushdown_dnf(predicates, sources, scope_ref_count)
def
pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    Push down the predicates of a CNF-like condition.

    Every predicate can be moved independently: a predicate that is pushed down is
    replaced by TRUE in its parent and added to the target join or subquery.

    Args:
        predicates: the AND-ed predicates of the condition.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts, forwarded to `nodes_for_predicate`.
        join_index: optional mapping of join name to its position in the select.
    """
    join_positions = join_index or {}

    for predicate in predicates:
        targets = nodes_for_predicate(predicate, sources, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Join):
                join_name = target.alias_or_name
                # Tables referenced by the predicate, other than the join itself.
                other_tables = exp.column_table_names(predicate, join_name)
                position = join_positions[join_name]

                # Don't push the predicate if it references tables of later joins.
                references_later_join = any(
                    join_positions.get(table, -1) >= position for table in other_tables
                )
                if not references_later_join:
                    predicate.replace(exp.true())
                    target.on(predicate, copy=False)
                    break
            elif isinstance(target, exp.Select):
                predicate.replace(exp.true())
                inner_predicate = replace_aliases(target, predicate)

                # Predicates containing an aggregate function go to HAVING, the
                # rest to WHERE.
                if find_in_scope(inner_predicate, exp.AggFunc):
                    attach = target.having
                else:
                    attach = target.where
                attach(inner_predicate, copy=False)
If the predicates are in CNF-like form, we can simply replace each block in the parent.
def
pushdown_dnf(predicates, sources, scope_ref_count):
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    Push down the predicates of a condition in DNF form.

    In DNF only conditions on tables referenced by every block can be pushed down.
    Additionally, the predicates can't be removed from their original position, so
    they are only added to the target nodes.

    Args:
        predicates: the OR-ed blocks of the condition.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts, forwarded to `nodes_for_predicate`.
    """
    # Find all the tables that can be pushed down to. These are the tables that are
    # referenced in all blocks of the DNF, e.g. in
    #   (a.x AND b.x) OR (a.y AND c.y)
    # only table a qualifies. Each block's table set is computed once and the sets
    # are intersected, instead of recomputing them for every pair of blocks.
    table_sets = [exp.column_table_names(predicate) for predicate in predicates]
    pushdown_tables = table_sets[0].intersection(*table_sets[1:]) if table_sets else set()

    conditions = {}

    # Push down all predicates to their respective nodes.
    for table in sorted(pushdown_tables):
        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, sources, scope_ref_count)

            if table not in nodes:
                continue

            # OR together the blocks that can be pushed down to this table.
            conditions[table] = (
                exp.or_(conditions[table], predicate) if table in conditions else predicate
            )

        # NOTE(review): `nodes` is the mapping computed for the last predicate above.
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                inner_predicate = replace_aliases(node, predicate)
                # Predicates containing an aggregate function go to HAVING, the
                # rest to WHERE.
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)
If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form.
def
nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map the tables referenced by `predicate` to the nodes it may be pushed down into.

    Args:
        predicate: the predicate to find target nodes for.
        sources: mapping of source name to a (node, source) pair.
        scope_ref_count: reference counts keyed by `id(source)`.

    Returns:
        A dict of table name to a Join or Select node. The dict is empty when a
        referenced source belongs to a recursive WITH or sits in a join whose side
        is set to anything other than RIGHT.
    """
    tables = exp.column_table_names(predicate)
    # True when the closest Join/Where ancestor of the predicate is a Where.
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
    single_table = len(tables) == 1

    targets = {}

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # A predicate that lives in a WHERE can be pushed down, so look for the
        # root JOIN or FROM of the source it refers to.
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # The source can be a CTE or subquery, in which case its own query is
        # the pushdown target - unless the enclosing WITH is recursive.
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with_")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        if not (isinstance(node, exp.Select) and single_table):
            continue

        # Predicates can't be pushed into grouped selects, nor into selects that
        # are referenced in multiple places.
        if node.args.get("group") or scope_ref_count[id(source)] >= 2:
            continue

        # We can't push down past window expressions either.
        if any(select.find(exp.Window) for select in node.selects):
            continue

        targets[table] = node

    return targets
def
replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Rewrite `predicate` in terms of the expressions projected by `source`.

    Every column of `predicate` whose name matches a projection of `source` is
    replaced by a copy of the projected expression (the aliased expression for an
    Alias, the projection itself otherwise).

    Returns:
        The result of `predicate.transform` with that replacement applied.
    """
    projections = {}

    for projection in source.selects:
        if isinstance(projection, exp.Alias):
            key, target = projection.alias, projection.this
        else:
            key, target = projection.name, projection
        projections[key] = target

    def _replace_alias(node):
        if isinstance(node, exp.Column):
            name = node.name
            if name in projections:
                return projections[name].copy()
        return node

    return predicate.transform(_replace_alias)