sqlglot.optimizer.pushdown_predicates
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify
from sqlglot import Dialect


def pushdown_predicates(expression, dialect=None):
    """
    Rewrite sqlglot AST to push down predicates into FROMs and JOINs.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: dialect (or dialect name) resolved through `Dialect.get_or_raise`;
            it is forwarded to `simplify` and decides whether the UNNEST
            restriction for Athena/Presto applies.
    Returns:
        sqlglot.Expression: optimized expression (the object that was passed in)
    """
    # Imported locally, as in the original listing.
    from sqlglot.dialects.athena import Athena
    from sqlglot.dialects.presto import Presto

    root = build_scope(expression)

    dialect = Dialect.get_or_raise(dialect)
    unnest_requires_cross_join = isinstance(dialect, (Athena, Presto))

    if root:
        scope_ref_count = root.ref_count()

        # Walk the scopes in reverse traversal order.
        for scope in reversed(list(root.traverse())):
            select = scope.expression
            where = select.args.get("where")
            if where:
                selected_sources = scope.selected_sources
                # Position of every join in this SELECT, keyed by alias/name.
                join_index = {
                    join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
                }

                # a right join can only push down to itself and not the source FROM table
                # presto, trino and athena don't support inner joins where the RHS is an UNNEST expression
                pushdown_allowed = True
                for k, (node, source) in selected_sources.items():
                    parent = node.find_ancestor(exp.Join, exp.From)
                    if isinstance(parent, exp.Join):
                        if parent.side == "RIGHT":
                            selected_sources = {k: (node, source)}
                            break
                        if isinstance(node, exp.Unnest) and unnest_requires_cross_join:
                            pushdown_allowed = False
                            break

                if pushdown_allowed:
                    pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)

            # joins should only pushdown into itself, not to other joins
            # so we limit the selected sources to only itself
            for join in select.args.get("joins") or []:
                name = join.alias_or_name
                if name in scope.selected_sources:
                    pushdown(
                        join.args.get("on"),
                        {name: scope.selected_sources[name]},
                        scope_ref_count,
                        dialect,
                    )

    return expression


def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify `condition` in place and push its predicates down.

    Args:
        condition: the WHERE / ON condition to process; a falsy value is a no-op.
        sources: mapping of source name to a `(node, source)` pair.
        scope_ref_count: reference counts indexed by `id(source)`.
        dialect: dialect forwarded to `simplify`.
        join_index: optional mapping of join name to its position; only used
            by the CNF path.
    """
    if not condition:
        return

    # Swap the condition for its simplified form inside the tree.
    condition = condition.replace(simplify(condition, dialect=dialect))
    # Anything that is not strictly DNF is handled by the CNF strategy.
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    predicates = list(
        condition.flatten()
        if isinstance(condition, exp.And if cnf_like else exp.Or)
        else [condition]
    )

    if cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
    else:
        pushdown_dnf(predicates, sources, scope_ref_count)


def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.
    """
    join_index = join_index or {}
    for predicate in predicates:
        for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
            if isinstance(node, exp.Join):
                name = node.alias_or_name
                predicate_tables = exp.column_table_names(predicate, name)

                # Don't push the predicate if it references tables that appear in later joins
                this_index = join_index[name]
                if all(join_index.get(table, -1) < this_index for table in predicate_tables):
                    # Leave TRUE behind, then attach the predicate to the join.
                    predicate.replace(exp.true())
                    node.on(predicate, copy=False)
                    break
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                inner_predicate = replace_aliases(node, predicate)
                # A predicate containing an aggregate goes to HAVING, otherwise to WHERE.
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)


def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.
    """
    if not predicates:
        return

    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    # The intersection over all blocks is computed once.
    pushdown_tables = set.intersection(
        *(set(exp.column_table_names(predicate)) for predicate in predicates)
    )

    conditions = {}

    # pushdown all predicates to their respective nodes
    for table in sorted(pushdown_tables):
        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, sources, scope_ref_count)

            if table not in nodes:
                continue

            conditions[table] = (
                exp.or_(conditions[table], predicate) if table in conditions else predicate
            )

        # NOTE: `nodes` is the mapping computed for the last predicate above.
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                inner_predicate = replace_aliases(node, predicate)
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)


def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by `predicate` to the node it may be pushed into.

    Returns:
        dict: table name -> `exp.Join` or `exp.Select` node. An empty dict is
        returned as soon as a recursive CTE or a sided, non-RIGHT join is met.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    # True when the nearest Join/Where ancestor of the predicate is a Where.
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select for select in node.selects if select.find(exp.Window)
            )
            # we can't push down predicates to select statements if they are referenced in
            # multiple places.
            if (
                not node.args.get("group")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
            ):
                nodes[table] = node
    return nodes


def replace_aliases(source, predicate):
    """
    Return `predicate` transformed so that columns named after a projection of
    `source` are replaced by a copy of that projection's expression.
    """
    aliases = {}

    for select in source.selects:
        if isinstance(select, exp.Alias):
            aliases[select.alias] = select.this
        else:
            aliases[select.name] = select

    def _replace_alias(column):
        if isinstance(column, exp.Column) and column.name in aliases:
            return aliases[column.name].copy()
        return column

    return predicate.transform(_replace_alias)
def
pushdown_predicates(expression, dialect=None):
def pushdown_predicates(expression, dialect=None):
    """
    Rewrite sqlglot AST to push down predicates into FROMs and JOINs.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: dialect (or dialect name) resolved through `Dialect.get_or_raise`.
    Returns:
        sqlglot.Expression: optimized expression (the object that was passed in)
    """
    from sqlglot.dialects.athena import Athena
    from sqlglot.dialects.presto import Presto

    root = build_scope(expression)

    dialect = Dialect.get_or_raise(dialect)
    unnest_requires_cross_join = isinstance(dialect, (Athena, Presto))

    if not root:
        return expression

    scope_ref_count = root.ref_count()

    for scope in reversed(list(root.traverse())):
        select = scope.expression
        where = select.args.get("where")

        if where:
            candidates = scope.selected_sources
            join_index = {}
            for position, join in enumerate(select.args.get("joins") or []):
                join_index[join.alias_or_name] = position

            # A RIGHT join may only receive predicates itself, never the FROM table.
            # Presto, Trino and Athena reject inner joins whose RHS is an UNNEST,
            # so the WHERE pushdown is skipped entirely in that situation.
            may_push = True
            for alias, (node, source) in candidates.items():
                enclosing = node.find_ancestor(exp.Join, exp.From)
                if not isinstance(enclosing, exp.Join):
                    continue
                if enclosing.side == "RIGHT":
                    candidates = {alias: (node, source)}
                    break
                if unnest_requires_cross_join and isinstance(node, exp.Unnest):
                    may_push = False
                    break

            if may_push:
                pushdown(where.this, candidates, scope_ref_count, dialect, join_index)

        # An ON condition is only ever pushed into its own join, so the
        # sources handed to `pushdown` are narrowed to that single join.
        for join in select.args.get("joins") or []:
            join_name = join.alias_or_name
            if join_name not in scope.selected_sources:
                continue
            own_source = {join_name: scope.selected_sources[join_name]}
            pushdown(join.args.get("on"), own_source, scope_ref_count, dialect)

    return expression
Rewrite a sqlglot AST to push down predicates into FROMs and JOINs.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
- expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
def
pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify `condition` in place, split it into predicates and push them down.

    A falsy `condition` is a no-op. Conditions that are CNF-like (normalized,
    or not in DNF) are handled by `pushdown_cnf`; the rest by `pushdown_dnf`.
    `join_index` is only forwarded to the CNF path.
    """
    if not condition:
        return

    simplified = simplify(condition, dialect=dialect)
    condition = condition.replace(simplified)

    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    # CNF is split on AND, DNF on OR; any other root is a single predicate.
    connector = exp.And if cnf_like else exp.Or
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    if not cnf_like:
        pushdown_dnf(predicates, sources, scope_ref_count)
        return

    pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
def
pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    Push CNF-like predicates down: each pushed predicate is replaced by TRUE
    in its parent and attached to the join or subquery it targets.
    """
    if not join_index:
        join_index = {}

    for predicate in predicates:
        targets = nodes_for_predicate(predicate, sources, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Join):
                join_name = target.alias_or_name
                other_tables = exp.column_table_names(predicate, join_name)
                own_position = join_index[join_name]

                # Skip the push when the predicate mentions a table that is
                # joined at or after this join.
                mentions_later_join = any(
                    join_index.get(other, -1) >= own_position for other in other_tables
                )
                if not mentions_later_join:
                    predicate.replace(exp.true())
                    target.on(predicate, copy=False)
                    break
            elif isinstance(target, exp.Select):
                predicate.replace(exp.true())
                rewritten = replace_aliases(target, predicate)
                # Aggregates cannot live in WHERE, so they are attached to HAVING.
                if find_in_scope(rewritten, exp.AggFunc):
                    target.having(rewritten, copy=False)
                else:
                    target.where(rewritten, copy=False)
If the predicates are in CNF like form, we can simply replace each block in the parent.
def
pushdown_dnf(predicates, sources, scope_ref_count):
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    Args:
        predicates: the OR-ed blocks of the DNF condition.
        sources: mapping of source name to a `(node, source)` pair.
        scope_ref_count: reference counts indexed by `id(source)`.
    """
    if not predicates:
        return

    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    # The intersection over all blocks is computed once instead of once per block.
    pushdown_tables = set.intersection(
        *(set(exp.column_table_names(predicate)) for predicate in predicates)
    )

    conditions = {}

    # pushdown all predicates to their respective nodes
    for table in sorted(pushdown_tables):
        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, sources, scope_ref_count)

            if table not in nodes:
                continue

            # OR the blocks back together per table.
            conditions[table] = (
                exp.or_(conditions[table], predicate) if table in conditions else predicate
            )

        # NOTE: `nodes` is the mapping computed for the last predicate above.
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                inner_predicate = replace_aliases(node, predicate)
                # A predicate containing an aggregate goes to HAVING, otherwise to WHERE.
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)
If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form.
def
nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by `predicate` to the node it may be pushed into.

    Returns a dict of table name -> `exp.Join` / `exp.Select`. An empty dict is
    returned as soon as a recursive CTE or a sided, non-RIGHT join is met.
    """
    targets = {}
    tables = exp.column_table_names(predicate)
    nearest_clause = predicate.find_ancestor(exp.Join, exp.Where)
    in_where = isinstance(nearest_clause, exp.Where)

    for table in sorted(tables):
        entry = sources.get(table)
        node, source = entry if entry else (None, None)

        # For a WHERE predicate, climb to the JOIN or FROM that owns the source.
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # A FROM whose source is not a plain table: push into the source's
        # own expression, unless the surrounding WITH is recursive.
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            side = node.side
            if side and side != "RIGHT":
                return {}
            targets[table] = node
            continue

        if not (isinstance(node, exp.Select) and len(tables) == 1):
            continue

        # Window expressions in the projections block the pushdown.
        uses_window = any(
            projection for projection in node.selects if projection.find(exp.Window)
        )

        # Grouped selects and sources referenced in multiple places are skipped too.
        if node.args.get("group"):
            continue
        if scope_ref_count[id(source)] >= 2:
            continue
        if uses_window:
            continue

        targets[table] = node

    return targets
def
replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Return `predicate` transformed so that every column whose name matches a
    projection of `source` is replaced by a copy of that projection's
    expression (the aliased expression for `exp.Alias` projections).
    """
    projections = {}

    for projection in source.selects:
        if isinstance(projection, exp.Alias):
            key, target = projection.alias, projection.this
        else:
            key, target = projection.name, projection
        projections[key] = target

    def _substitute(node):
        # Only columns that name a known projection are swapped out.
        if not isinstance(node, exp.Column):
            return node
        if node.name not in projections:
            return node
        return projections[node.name].copy()

    return predicate.transform(_substitute)