Edit on GitHub

sqlglot.optimizer.pushdown_predicates

  1from sqlglot import exp
  2from sqlglot.optimizer.normalize import normalized
  3from sqlglot.optimizer.scope import build_scope, find_in_scope
  4from sqlglot.optimizer.simplify import simplify
  5from sqlglot import Dialect
  6
  7
  8def pushdown_predicates(expression, dialect=None):
  9    """
 10    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
 15        >>> expression = sqlglot.parse_one(sql)
 16        >>> pushdown_predicates(expression).sql()
 17        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
 18
 19    Args:
 20        expression (sqlglot.Expression): expression to optimize
 21    Returns:
 22        sqlglot.Expression: optimized expression
 23    """
 24    from sqlglot.dialects.presto import Presto
 25
 26    root = build_scope(expression)
 27
 28    dialect = Dialect.get_or_raise(dialect)
 29    unnest_requires_cross_join = isinstance(dialect, Presto)
 30
 31    if root:
 32        scope_ref_count = root.ref_count()
 33
 34        for scope in reversed(list(root.traverse())):
 35            select = scope.expression
 36            where = select.args.get("where")
 37            if where:
 38                selected_sources = scope.selected_sources
 39                join_index = {
 40                    join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
 41                }
 42
 43                # a right join can only push down to itself and not the source FROM table
 44                # presto, trino and athena don't support inner joins where the RHS is an UNNEST expression
 45                pushdown_allowed = True
 46                for k, (node, source) in selected_sources.items():
 47                    parent = node.find_ancestor(exp.Join, exp.From)
 48                    if isinstance(parent, exp.Join):
 49                        if parent.side == "RIGHT":
 50                            selected_sources = {k: (node, source)}
 51                            break
 52                        if isinstance(node, exp.Unnest) and unnest_requires_cross_join:
 53                            pushdown_allowed = False
 54                            break
 55
 56                if pushdown_allowed:
 57                    pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)
 58
 59            # joins should only pushdown into itself, not to other joins
 60            # so we limit the selected sources to only itself
 61            for join in select.args.get("joins") or []:
 62                name = join.alias_or_name
 63                if name in scope.selected_sources:
 64                    pushdown(
 65                        join.args.get("on"),
 66                        {name: scope.selected_sources[name]},
 67                        scope_ref_count,
 68                        dialect,
 69                    )
 70
 71    return expression
 72
 73
 74def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
 75    if not condition:
 76        return
 77
 78    condition = condition.replace(simplify(condition, dialect=dialect))
 79    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
 80
 81    predicates = list(
 82        condition.flatten()
 83        if isinstance(condition, exp.And if cnf_like else exp.Or)
 84        else [condition]
 85    )
 86
 87    if cnf_like:
 88        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
 89    else:
 90        pushdown_dnf(predicates, sources, scope_ref_count)
 91
 92
 93def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
 94    """
 95    If the predicates are in CNF like form, we can simply replace each block in the parent.
 96    """
 97    join_index = join_index or {}
 98    for predicate in predicates:
 99        for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
100            if isinstance(node, exp.Join):
101                name = node.alias_or_name
102                predicate_tables = exp.column_table_names(predicate, name)
103
104                # Don't push the predicate if it references tables that appear in later joins
105                this_index = join_index[name]
106                if all(join_index.get(table, -1) < this_index for table in predicate_tables):
107                    predicate.replace(exp.true())
108                    node.on(predicate, copy=False)
109                    break
110            if isinstance(node, exp.Select):
111                predicate.replace(exp.true())
112                inner_predicate = replace_aliases(node, predicate)
113                if find_in_scope(inner_predicate, exp.AggFunc):
114                    node.having(inner_predicate, copy=False)
115                else:
116                    node.where(inner_predicate, copy=False)
117
118
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    Args:
        predicates: the OR-ed blocks of the condition.
        sources: mapping of source name to a (node, source) pair, forwarded to
            `nodes_for_predicate`.
        scope_ref_count: reference counts forwarded to `nodes_for_predicate`.
    """
    if not predicates:
        return

    # Find all the tables that can be pushed down to: these are the tables
    # referenced in every block of the DNF, e.g. in
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down. One pass over the blocks is enough.
    pushdown_tables = set(exp.column_table_names(predicates[0]))
    for predicate in predicates[1:]:
        pushdown_tables &= exp.column_table_names(predicate)

    if not pushdown_tables:
        return

    # Resolve the candidate target nodes once per block instead of once per
    # (table, block) pair.
    block_nodes = [
        nodes_for_predicate(predicate, sources, scope_ref_count) for predicate in predicates
    ]

    for table in sorted(pushdown_tables):
        # Every block must be able to target this table. OR-ing only some of the
        # blocks would produce a filter stricter than the original condition.
        if not all(table in nodes for nodes in block_nodes):
            continue

        node = block_nodes[-1][table]

        # Build the disjunction of all blocks; `exp.or_` copies its operands, so
        # the original predicates stay where they are.
        condition = predicates[0]
        for predicate in predicates[1:]:
            condition = exp.or_(condition, predicate)
        if condition is predicates[0]:
            # Single block: copy it so the original stays attached to its parent.
            condition = condition.copy()

        # Each table's condition is attached exactly once.
        if isinstance(node, exp.Join):
            node.on(condition, copy=False)
        elif isinstance(node, exp.Select):
            inner_predicate = replace_aliases(node, condition)
            # A rewritten predicate containing an aggregate goes to HAVING instead of WHERE.
            if find_in_scope(inner_predicate, exp.AggFunc):
                node.having(inner_predicate, copy=False)
            else:
                node.where(inner_predicate, copy=False)
166
167
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by `predicate` to the node it may be pushed into.

    Returns a dict of table name to a Join or Select node. An empty dict is
    returned as soon as a recursive WITH or a sided join other than RIGHT is
    encountered. Tables with no eligible node are left out.
    """
    targets = {}
    referenced = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(referenced):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        if not isinstance(node, exp.Select) or len(referenced) != 1:
            continue

        # We can't push down window expressions
        has_window_expression = any(
            select for select in node.selects if select.find(exp.Window)
        )
        # we can't push down predicates to select statements if they are referenced in
        # multiple places.
        eligible = (
            not node.args.get("group")
            and scope_ref_count[id(source)] < 2
            and not has_window_expression
        )
        if eligible:
            targets[table] = node

    return targets
206
207
def replace_aliases(source, predicate):
    """
    Return `predicate` transformed so that every column whose name matches one of
    `source`'s projections is replaced by a copy of that projection's expression.
    """
    projections = {}
    for projection in source.selects:
        aliased = isinstance(projection, exp.Alias)
        key = projection.alias if aliased else projection.name
        # For an alias, map to the aliased expression rather than the Alias node.
        projections[key] = projection.this if aliased else projection

    def _replace_alias(node):
        if isinstance(node, exp.Column):
            name = node.name
            if name in projections:
                return projections[name].copy()
        return node

    return predicate.transform(_replace_alias)
def pushdown_predicates(expression, dialect=None):
    """
    Rewrite sqlglot AST to push predicates down into FROMs and JOINs.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: value passed to `Dialect.get_or_raise`; when the resolved dialect
            is an instance of `Presto`, WHERE pushdown is skipped for scopes that
            join against an UNNEST.
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    """
    from sqlglot.dialects.presto import Presto

    root = build_scope(expression)

    # Resolved before the early return so the dialect is always looked up.
    dialect = Dialect.get_or_raise(dialect)
    unnest_requires_cross_join = isinstance(dialect, Presto)

    if not root:
        return expression

    scope_ref_count = root.ref_count()

    # Visit the scopes in the reverse of the order `traverse` yields them.
    for scope in reversed(list(root.traverse())):
        select = scope.expression
        joins = select.args.get("joins") or []
        where = select.args.get("where")

        if where:
            candidates = scope.selected_sources
            join_index = {join.alias_or_name: position for position, join in enumerate(joins)}

            # a right join can only push down to itself and not the source FROM table
            # presto, trino and athena don't support inner joins where the RHS is an UNNEST expression
            pushdown_allowed = True
            for alias, (node, source) in candidates.items():
                parent = node.find_ancestor(exp.Join, exp.From)
                if not isinstance(parent, exp.Join):
                    continue
                if parent.side == "RIGHT":
                    candidates = {alias: (node, source)}
                    break
                if unnest_requires_cross_join and isinstance(node, exp.Unnest):
                    pushdown_allowed = False
                    break

            if pushdown_allowed:
                pushdown(where.this, candidates, scope_ref_count, dialect, join_index)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself
        for join in joins:
            name = join.alias_or_name
            if name not in scope.selected_sources:
                continue
            own_source = {name: scope.selected_sources[name]}
            pushdown(join.args.get("on"), own_source, scope_ref_count, dialect)

    return expression

Rewrite the sqlglot AST to push predicates down into FROMs and JOINs.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
Returns:
  • sqlglot.Expression: optimized expression

def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify `condition` in place, split it into predicates and hand them to the
    CNF or DNF pushdown routine.

    Does nothing when `condition` is falsy. `join_index` is only forwarded on
    the CNF path.
    """
    if not condition:
        return

    simplified = simplify(condition, dialect=dialect)
    condition = condition.replace(simplified)

    # Anything that is not strictly DNF is handled as CNF-like.
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    connector = exp.And if cnf_like else exp.Or
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    if not cnf_like:
        pushdown_dnf(predicates, sources, scope_ref_count)
        return

    pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    Push each predicate of a CNF-like condition down to the node that can take it.

    A predicate that is pushed down is replaced by TRUE at its original position.
    `join_index` maps a join's name to its position in the select's join list;
    it defaults to an empty mapping.
    """
    join_positions = join_index or {}

    for predicate in predicates:
        targets = nodes_for_predicate(predicate, sources, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Join):
                alias = target.alias_or_name
                other_tables = exp.column_table_names(predicate, alias)

                # Don't push the predicate if it references tables that appear in later joins
                own_position = join_positions[alias]
                references_later_join = any(
                    join_positions.get(table, -1) >= own_position for table in other_tables
                )
                if not references_later_join:
                    predicate.replace(exp.true())
                    target.on(predicate, copy=False)
                    break
            elif isinstance(target, exp.Select):
                predicate.replace(exp.true())
                rewritten = replace_aliases(target, predicate)
                # A rewritten predicate containing an aggregate goes to HAVING instead of WHERE.
                if find_in_scope(rewritten, exp.AggFunc):
                    target.having(rewritten, copy=False)
                else:
                    target.where(rewritten, copy=False)

If the predicates are in CNF like form, we can simply replace each block in the parent.

def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    Args:
        predicates: the OR-ed blocks of the condition.
        sources: mapping of source name to a (node, source) pair, forwarded to
            `nodes_for_predicate`.
        scope_ref_count: reference counts forwarded to `nodes_for_predicate`.
    """
    if not predicates:
        return

    # Find all the tables that can be pushed down to: these are the tables
    # referenced in every block of the DNF, e.g. in
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down. One pass over the blocks is enough.
    pushdown_tables = set(exp.column_table_names(predicates[0]))
    for predicate in predicates[1:]:
        pushdown_tables &= exp.column_table_names(predicate)

    if not pushdown_tables:
        return

    # Resolve the candidate target nodes once per block instead of once per
    # (table, block) pair.
    block_nodes = [
        nodes_for_predicate(predicate, sources, scope_ref_count) for predicate in predicates
    ]

    for table in sorted(pushdown_tables):
        # Every block must be able to target this table. OR-ing only some of the
        # blocks would produce a filter stricter than the original condition.
        if not all(table in nodes for nodes in block_nodes):
            continue

        node = block_nodes[-1][table]

        # Build the disjunction of all blocks; `exp.or_` copies its operands, so
        # the original predicates stay where they are.
        condition = predicates[0]
        for predicate in predicates[1:]:
            condition = exp.or_(condition, predicate)
        if condition is predicates[0]:
            # Single block: copy it so the original stays attached to its parent.
            condition = condition.copy()

        # Each table's condition is attached exactly once.
        if isinstance(node, exp.Join):
            node.on(condition, copy=False)
        elif isinstance(node, exp.Select):
            inner_predicate = replace_aliases(node, condition)
            # A rewritten predicate containing an aggregate goes to HAVING instead of WHERE.
            if find_in_scope(inner_predicate, exp.AggFunc):
                node.having(inner_predicate, copy=False)
            else:
                node.where(inner_predicate, copy=False)

If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form.

def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by `predicate` to the node it may be pushed into.

    Returns a dict of table name to a Join or Select node. An empty dict is
    returned as soon as a recursive WITH or a sided join other than RIGHT is
    encountered. Tables with no eligible node are left out.
    """
    targets = {}
    referenced = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(referenced):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and in_where:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        if not isinstance(node, exp.Select) or len(referenced) != 1:
            continue

        # We can't push down window expressions
        has_window_expression = any(
            select for select in node.selects if select.find(exp.Window)
        )
        # we can't push down predicates to select statements if they are referenced in
        # multiple places.
        eligible = (
            not node.args.get("group")
            and scope_ref_count[id(source)] < 2
            and not has_window_expression
        )
        if eligible:
            targets[table] = node

    return targets
def replace_aliases(source, predicate):
    """
    Return `predicate` transformed so that every column whose name matches one of
    `source`'s projections is replaced by a copy of that projection's expression.
    """
    projections = {}
    for projection in source.selects:
        aliased = isinstance(projection, exp.Alias)
        key = projection.alias if aliased else projection.name
        # For an alias, map to the aliased expression rather than the Alias node.
        projections[key] = projection.this if aliased else projection

    def _replace_alias(node):
        if isinstance(node, exp.Column):
            name = node.name
            if name in projections:
                return projections[name].copy()
        return node

    return predicate.transform(_replace_alias)