sqlglot.optimizer.pushdown_predicates
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify


def pushdown_predicates(expression, dialect=None):
    """
    Rewrite sqlglot AST to push predicates down into FROM and JOIN sources.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: forwarded to ``simplify`` when a condition is simplified
    Returns:
        sqlglot.Expression: optimized expression
    """
    root = build_scope(expression)
    if not root:
        return expression

    scope_ref_count = root.ref_count()

    # scopes are visited in the reverse of their traversal order
    for scope in reversed(list(root.traverse())):
        select = scope.expression
        joins = select.args.get("joins") or []
        where = select.args.get("where")

        if where:
            where_sources = scope.selected_sources
            join_index = {join.alias_or_name: position for position, join in enumerate(joins)}

            # a right join can only push down to itself and not the source FROM table
            for alias, (node, source) in where_sources.items():
                origin = node.find_ancestor(exp.Join, exp.From)
                if isinstance(origin, exp.Join) and origin.side == "RIGHT":
                    where_sources = {alias: (node, source)}
                    break

            pushdown(where.this, where_sources, scope_ref_count, dialect, join_index)

        # an ON condition should only be pushed into its own join, not into other joins,
        # so the sources are limited to that single join
        for join in joins:
            alias = join.alias_or_name
            if alias in scope.selected_sources:
                pushdown(
                    join.args.get("on"),
                    {alias: scope.selected_sources[alias]},
                    scope_ref_count,
                    dialect,
                )

    return expression


def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify ``condition`` in place, split it into predicates and push those down.

    Args:
        condition: the condition of a WHERE or ON clause; nothing happens if it is falsy.
        sources: mapping of source name to ``(node, source)``.
        scope_ref_count: reference counts, indexed by ``id(source)``.
        dialect: forwarded to ``simplify``.
        join_index: mapping of join name to its position among the joins; only
            forwarded to ``pushdown_cnf``.
    """
    if not condition:
        return

    # the simplified condition takes the place of the original one in the tree
    condition = condition.replace(simplify(condition, dialect=dialect))

    # anything that is not strictly DNF is handled as if it were CNF
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
    connector = exp.And if cnf_like else exp.Or

    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    if cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
    else:
        pushdown_dnf(predicates, sources, scope_ref_count)


def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    Push down the predicates of a CNF-like condition.

    Every predicate that gets pushed down is replaced by TRUE in its parent.

    Args:
        predicates: the AND-ed predicates of the condition.
        sources: mapping of source name to ``(node, source)``.
        scope_ref_count: reference counts, indexed by ``id(source)``.
        join_index: mapping of join name to its position among the joins.
    """
    join_positions = join_index or {}

    for predicate in predicates:
        targets = nodes_for_predicate(predicate, sources, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Join):
                join_name = target.alias_or_name
                position = join_positions[join_name]
                other_tables = exp.column_table_names(predicate, join_name)

                # Don't push the predicate if it references tables that appear in later joins
                if any(join_positions.get(table, -1) >= position for table in other_tables):
                    continue

                # detach the predicate from its parent before attaching it to the join
                predicate.replace(exp.true())
                target.on(predicate, copy=False)
                break
            elif isinstance(target, exp.Select):
                predicate.replace(exp.true())
                inner_predicate = replace_aliases(target, predicate)

                # a predicate containing an aggregate goes to HAVING instead of WHERE
                if find_in_scope(inner_predicate, exp.AggFunc):
                    target.having(inner_predicate, copy=False)
                else:
                    target.where(inner_predicate, copy=False)


def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    Push down the blocks of a DNF condition without altering the original condition.

    ``predicates`` are the OR-ed blocks of a condition in disjunctive normal form. A
    filter applied early must be implied by the whole condition, so the disjunction of
    *all* blocks is attached to a table's node, and only when every block references
    that table and can itself be pushed to that node. The blocks are never removed from
    their original position.

    Args:
        predicates: the OR-ed blocks of the condition.
        sources: mapping of source name to ``(node, source)``.
        scope_ref_count: reference counts, indexed by ``id(source)``.
    """
    if not predicates:
        return

    # find all the tables that can be pushed down to:
    # these are tables that are referenced in all blocks of a DNF, e.g. for
    #   (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down to
    pushdown_tables = set.intersection(
        *(set(exp.column_table_names(predicate)) for predicate in predicates)
    )
    if not pushdown_tables:
        return

    # the target nodes of a block do not depend on the table being processed,
    # so they are resolved once per block
    nodes_by_predicate = [
        nodes_for_predicate(predicate, sources, scope_ref_count) for predicate in predicates
    ]

    for table in sorted(pushdown_tables):
        # if a block cannot be pushed to this table, filtering on the remaining blocks
        # alone would discard rows that the skipped block still accepts
        if any(table not in nodes for nodes in nodes_by_predicate):
            continue

        node = nodes_by_predicate[0][table]

        # build a separate disjunction per table; the first block is copied explicitly
        # so that the original is never re-parented (exp.or_ is expected to copy its
        # operands by default - confirm against the sqlglot version in use)
        condition = None
        for predicate in predicates:
            condition = predicate.copy() if condition is None else exp.or_(condition, predicate)

        if isinstance(node, exp.Join):
            node.on(condition, copy=False)
        elif isinstance(node, exp.Select):
            inner_condition = replace_aliases(node, condition)

            # a condition containing an aggregate goes to HAVING instead of WHERE
            if find_in_scope(inner_condition, exp.AggFunc):
                node.having(inner_condition, copy=False)
            else:
                node.where(inner_condition, copy=False)


def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find the nodes that ``predicate`` can be pushed down to.

    Args:
        predicate: the predicate to place.
        sources: mapping of source name to ``(node, source)``.
        scope_ref_count: reference counts, indexed by ``id(source)``.
    Returns:
        dict: mapping of table name to the ``exp.Join`` or ``exp.Select`` node that can
        receive the predicate; empty if nothing can be pushed down.
    """
    tables = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
    single_table = len(tables) == 1
    targets = {}

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # a join with a side other than RIGHT rules out any pushdown of this predicate
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        # a select only accepts predicates that reference a single table
        if not (single_table and isinstance(node, exp.Select)):
            continue

        # we can't push down predicates to select statements that are grouped or that
        # are referenced in multiple places
        if node.args.get("group") or scope_ref_count[id(source)] >= 2:
            continue

        # We can't push down window expressions
        if any(select.find(exp.Window) for select in node.selects):
            continue

        targets[table] = node

    return targets


def replace_aliases(source, predicate):
    """
    Rewrite ``predicate`` in terms of the expressions selected by ``source``.

    Args:
        source: the select whose projections are looked up.
        predicate: the predicate to rewrite.
    Returns:
        The result of ``predicate.transform`` in which every column whose name matches a
        projection of ``source`` is replaced by a copy of the projected expression.
    """
    projections = {}

    for projection in source.selects:
        if isinstance(projection, exp.Alias):
            name, target = projection.alias, projection.this
        else:
            name, target = projection.name, projection
        projections[name] = target

    def _substitute(node):
        if not isinstance(node, exp.Column) or node.name not in projections:
            return node
        return projections[node.name].copy()

    return predicate.transform(_substitute)
def
pushdown_predicates(expression, dialect=None):
def pushdown_predicates(expression, dialect=None):
    """
    Rewrite sqlglot AST to push predicates down into FROM and JOIN sources.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: forwarded to ``simplify`` when a condition is simplified
    Returns:
        sqlglot.Expression: optimized expression
    """
    root = build_scope(expression)
    if not root:
        return expression

    scope_ref_count = root.ref_count()

    # scopes are visited in the reverse of their traversal order
    for scope in reversed(list(root.traverse())):
        select = scope.expression
        joins = select.args.get("joins") or []
        where = select.args.get("where")

        if where:
            where_sources = scope.selected_sources
            join_index = {join.alias_or_name: position for position, join in enumerate(joins)}

            # a right join can only push down to itself and not the source FROM table
            for alias, (node, source) in where_sources.items():
                origin = node.find_ancestor(exp.Join, exp.From)
                if isinstance(origin, exp.Join) and origin.side == "RIGHT":
                    where_sources = {alias: (node, source)}
                    break

            pushdown(where.this, where_sources, scope_ref_count, dialect, join_index)

        # an ON condition should only be pushed into its own join, not into other joins,
        # so the sources are limited to that single join
        for join in joins:
            alias = join.alias_or_name
            if alias in scope.selected_sources:
                pushdown(
                    join.args.get("on"),
                    {alias: scope.selected_sources[alias]},
                    scope_ref_count,
                    dialect,
                )

    return expression
Rewrite a sqlglot AST so that predicates are pushed down into FROM and JOIN sources.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
- expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
def
pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify ``condition`` in place, split it into predicates and push those down.

    Args:
        condition: the condition of a WHERE or ON clause; nothing happens if it is falsy.
        sources: mapping of source name to ``(node, source)``.
        scope_ref_count: reference counts, indexed by ``id(source)``.
        dialect: forwarded to ``simplify``.
        join_index: mapping of join name to its position among the joins; only
            forwarded to ``pushdown_cnf``.
    """
    if not condition:
        return

    # the simplified condition takes the place of the original one in the tree
    condition = condition.replace(simplify(condition, dialect=dialect))

    # anything that is not strictly DNF is handled as if it were CNF
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
    connector = exp.And if cnf_like else exp.Or

    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    if cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
    else:
        pushdown_dnf(predicates, sources, scope_ref_count)
def
pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    Push down the predicates of a CNF-like condition.

    Every predicate that gets pushed down is replaced by TRUE in its parent.

    Args:
        predicates: the AND-ed predicates of the condition.
        sources: mapping of source name to ``(node, source)``.
        scope_ref_count: reference counts, indexed by ``id(source)``.
        join_index: mapping of join name to its position among the joins.
    """
    join_positions = join_index or {}

    for predicate in predicates:
        targets = nodes_for_predicate(predicate, sources, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Join):
                join_name = target.alias_or_name
                position = join_positions[join_name]
                other_tables = exp.column_table_names(predicate, join_name)

                # Don't push the predicate if it references tables that appear in later joins
                if any(join_positions.get(table, -1) >= position for table in other_tables):
                    continue

                # detach the predicate from its parent before attaching it to the join
                predicate.replace(exp.true())
                target.on(predicate, copy=False)
                break
            elif isinstance(target, exp.Select):
                predicate.replace(exp.true())
                inner_predicate = replace_aliases(target, predicate)

                # a predicate containing an aggregate goes to HAVING instead of WHERE
                if find_in_scope(inner_predicate, exp.AggFunc):
                    target.having(inner_predicate, copy=False)
                else:
                    target.where(inner_predicate, copy=False)
If the predicates are in CNF-like form, we can simply replace each block in the parent.
def
pushdown_dnf(predicates, sources, scope_ref_count):
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    Push down the blocks of a DNF condition without altering the original condition.

    ``predicates`` are the OR-ed blocks of a condition in disjunctive normal form. A
    filter applied early must be implied by the whole condition, so the disjunction of
    *all* blocks is attached to a table's node, and only when every block references
    that table and can itself be pushed to that node. The blocks are never removed from
    their original position.

    Args:
        predicates: the OR-ed blocks of the condition.
        sources: mapping of source name to ``(node, source)``.
        scope_ref_count: reference counts, indexed by ``id(source)``.
    """
    if not predicates:
        return

    # find all the tables that can be pushed down to:
    # these are tables that are referenced in all blocks of a DNF, e.g. for
    #   (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down to
    pushdown_tables = set.intersection(
        *(set(exp.column_table_names(predicate)) for predicate in predicates)
    )
    if not pushdown_tables:
        return

    # the target nodes of a block do not depend on the table being processed,
    # so they are resolved once per block
    nodes_by_predicate = [
        nodes_for_predicate(predicate, sources, scope_ref_count) for predicate in predicates
    ]

    for table in sorted(pushdown_tables):
        # if a block cannot be pushed to this table, filtering on the remaining blocks
        # alone would discard rows that the skipped block still accepts
        if any(table not in nodes for nodes in nodes_by_predicate):
            continue

        node = nodes_by_predicate[0][table]

        # build a separate disjunction per table; the first block is copied explicitly
        # so that the original is never re-parented (exp.or_ is expected to copy its
        # operands by default - confirm against the sqlglot version in use)
        condition = None
        for predicate in predicates:
            condition = predicate.copy() if condition is None else exp.or_(condition, predicate)

        if isinstance(node, exp.Join):
            node.on(condition, copy=False)
        elif isinstance(node, exp.Select):
            inner_condition = replace_aliases(node, condition)

            # a condition containing an aggregate goes to HAVING instead of WHERE
            if find_in_scope(inner_condition, exp.AggFunc):
                node.having(inner_condition, copy=False)
            else:
                node.where(inner_condition, copy=False)
If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form.
def
nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find the nodes that ``predicate`` can be pushed down to.

    Args:
        predicate: the predicate to place.
        sources: mapping of source name to ``(node, source)``.
        scope_ref_count: reference counts, indexed by ``id(source)``.
    Returns:
        dict: mapping of table name to the ``exp.Join`` or ``exp.Select`` node that can
        receive the predicate; empty if nothing can be pushed down.
    """
    tables = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
    single_table = len(tables) == 1
    targets = {}

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # a join with a side other than RIGHT rules out any pushdown of this predicate
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        # a select only accepts predicates that reference a single table
        if not (single_table and isinstance(node, exp.Select)):
            continue

        # we can't push down predicates to select statements that are grouped or that
        # are referenced in multiple places
        if node.args.get("group") or scope_ref_count[id(source)] >= 2:
            continue

        # We can't push down window expressions
        if any(select.find(exp.Window) for select in node.selects):
            continue

        targets[table] = node

    return targets
def
replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Rewrite ``predicate`` in terms of the expressions selected by ``source``.

    Args:
        source: the select whose projections are looked up.
        predicate: the predicate to rewrite.
    Returns:
        The result of ``predicate.transform`` in which every column whose name matches a
        projection of ``source`` is replaced by a copy of the projected expression.
    """
    projections = {}

    for projection in source.selects:
        if isinstance(projection, exp.Alias):
            name, target = projection.alias, projection.this
        else:
            name, target = projection.name, projection
        projections[name] = target

    def _substitute(node):
        if not isinstance(node, exp.Column) or node.name not in projections:
            return node
        return projections[node.name].copy()

    return predicate.transform(_substitute)