Edit on GitHub

sqlglot.optimizer.pushdown_predicates

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import exp
  6from sqlglot.optimizer.normalize import normalized
  7from sqlglot.optimizer.scope import build_scope, find_in_scope
  8from sqlglot.optimizer.simplify import simplify
  9from sqlglot import Dialect
 10
 11if t.TYPE_CHECKING:
 12    from sqlglot._typing import E
 13    from sqlglot.dialects.dialect import DialectType
 14
 15
 16def pushdown_predicates(expression: E, dialect: DialectType = None) -> E:
 17    """
 18    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
 19
 20    Example:
 21        >>> import sqlglot
 22        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
 23        >>> expression = sqlglot.parse_one(sql)
 24        >>> pushdown_predicates(expression).sql()
 25        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
 26
 27    Args:
 28        expression (sqlglot.Expression): expression to optimize
 29    Returns:
 30        sqlglot.Expression: optimized expression
 31    """
 32    from sqlglot.dialects.athena import Athena
 33    from sqlglot.dialects.presto import Presto
 34
 35    root = build_scope(expression)
 36
 37    dialect = Dialect.get_or_raise(dialect)
 38    unnest_requires_cross_join = isinstance(dialect, (Athena, Presto))
 39
 40    if root:
 41        scope_ref_count = root.ref_count()
 42
 43        for scope in reversed(list(root.traverse())):
 44            select = scope.expression
 45            where = select.args.get("where")
 46            if where:
 47                selected_sources = scope.selected_sources
 48                join_index = {
 49                    join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
 50                }
 51
 52                # a right join can only push down to itself and not the source FROM table
 53                # presto, trino and athena don't support inner joins where the RHS is an UNNEST expression
 54                pushdown_allowed = True
 55                for k, (node, source) in selected_sources.items():
 56                    parent = node.find_ancestor(exp.Join, exp.From)
 57                    if isinstance(parent, exp.Join):
 58                        if parent.side == "RIGHT":
 59                            selected_sources = {k: (node, source)}
 60                            break
 61                        if isinstance(node, exp.Unnest) and unnest_requires_cross_join:
 62                            pushdown_allowed = False
 63                            break
 64
 65                if pushdown_allowed:
 66                    pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)
 67
 68            # joins should only pushdown into itself, not to other joins
 69            # so we limit the selected sources to only itself
 70            for join in select.args.get("joins") or []:
 71                name = join.alias_or_name
 72                if name in scope.selected_sources:
 73                    pushdown(
 74                        join.args.get("on"),
 75                        {name: scope.selected_sources[name]},
 76                        scope_ref_count,
 77                        dialect,
 78                    )
 79
 80    return expression
 81
 82
 83def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
 84    if not condition:
 85        return
 86
 87    condition = condition.replace(simplify(condition, dialect=dialect))
 88    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
 89
 90    predicates = list(
 91        condition.flatten()
 92        if isinstance(condition, exp.And if cnf_like else exp.Or)
 93        else [condition]
 94    )
 95
 96    if cnf_like:
 97        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
 98    else:
 99        pushdown_dnf(predicates, sources, scope_ref_count)
100
101
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    If the predicates are in CNF-like form, each block can be pushed down on its own
    and replaced with TRUE in the parent.

    A block that targets a JOIN is attached to that JOIN's ON clause only when every
    other table it references is joined earlier according to ``join_index`` (tables
    absent from it count as earlier); ``join_index`` must contain the JOIN's alias.
    A block that targets a SELECT is rewritten through the SELECT's aliases and added
    to its HAVING clause if it contains an aggregate, otherwise to its WHERE clause.
    """
    positions = join_index or {}

    for predicate in predicates:
        targets = nodes_for_predicate(predicate, sources, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Select):
                predicate.replace(exp.true())
                pushed = replace_aliases(target, predicate)
                if find_in_scope(pushed, exp.AggFunc):
                    target.having(pushed, copy=False)
                else:
                    target.where(pushed, copy=False)
                continue

            if not isinstance(target, exp.Join):
                continue

            alias = target.alias_or_name
            referenced = exp.column_table_names(predicate, alias)

            # Don't push the predicate if it references tables that appear in later joins
            own_position = positions[alias]
            if any(positions.get(table, -1) >= own_position for table in referenced):
                continue

            predicate.replace(exp.true())
            target.on(predicate, copy=False)
            break
126
127
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions on tables that
    are referenced in every block. Additionally, we can't remove predicates from their
    original form.

    For each such table, the disjunction of all blocks is pushed into the table's
    target node, but only when *every* block resolves to a pushdown target for that
    table. Pushing a subset of the blocks would be a stronger filter than the original
    condition and would drop rows matched only by the other blocks. Each target
    receives its condition exactly once, and the pushed condition is always a copy, so
    the original predicates stay in place.
    """
    if not predicates:
        return

    # find all the tables that can be pushed down to:
    # these are tables that are referenced in all blocks of a DNF, e.g. in
    # (a.x AND b.x) OR (a.y AND c.y) only table a can be pushed down to
    table_sets = [exp.column_table_names(predicate) for predicate in predicates]
    pushdown_tables = set(table_sets[0]).intersection(*table_sets[1:])

    if not pushdown_tables:
        return

    # Target resolution does not depend on the table being processed, and pushing
    # copies leaves the original predicates where they are, so resolve once per block.
    resolved = [
        nodes_for_predicate(predicate, sources, scope_ref_count) for predicate in predicates
    ]

    # pushdown the combined predicates to their respective nodes
    for table in sorted(pushdown_tables):
        target = None
        condition = None

        for predicate, nodes in zip(predicates, resolved):
            node = nodes.get(table)
            if node is None:
                # this block can't be pushed into the table, so neither can the OR
                target = None
                break
            target = node
            condition = predicate.copy() if condition is None else exp.or_(condition, predicate)

        if target is None:
            continue

        # NOTE(review): unlike pushdown_cnf, there is no check that a JOIN target does not
        # reference tables joined after it; confirm this is safe for multi-join DNF filters.
        if isinstance(target, exp.Join):
            target.on(condition, copy=False)
        elif isinstance(target, exp.Select):
            inner_predicate = replace_aliases(target, condition)
            if find_in_scope(inner_predicate, exp.AggFunc):
                target.having(inner_predicate, copy=False)
            else:
                target.where(inner_predicate, copy=False)
175
176
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by ``predicate`` to the node it may be pushed into.

    For predicates inside a WHERE, a source is resolved to its enclosing JOIN or FROM;
    a non-table FROM source (derived table/CTE) is resolved to its scope's expression.
    JOIN targets are kept unless the join has a side other than RIGHT. SELECT targets
    are kept only for single-table predicates, when the SELECT has no GROUP BY, its
    source is referenced fewer than twice, and no projection contains a window.

    Returns an empty dict (nothing may be pushed) if any referenced source is in a
    sided non-RIGHT join, or is a non-table FROM source whose parent scope's WITH
    clause is recursive.
    """
    referenced = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
    targets = {}

    for name in sorted(referenced):
        target, source = sources.get(name) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # to the root join or from statement
        if target and in_where:
            target = target.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(target, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with_")
            if with_ and with_.recursive:
                return {}
            target = source.expression

        if isinstance(target, exp.Join):
            if target.side and target.side != "RIGHT":
                return {}
            targets[name] = target
            continue

        if not isinstance(target, exp.Select) or len(referenced) != 1:
            continue
        if target.args.get("group"):
            continue
        # can't push into a select that is referenced in multiple places
        if scope_ref_count[id(source)] >= 2:
            continue
        # can't push past window expressions
        if any(projection.find(exp.Window) for projection in target.selects):
            continue

        targets[name] = target

    return targets
215
216
def replace_aliases(source, predicate):
    """
    Rewrite ``predicate`` in terms of the projections of ``source``.

    Each column whose name matches an output name of ``source`` is replaced by a copy
    of the expression producing it: the aliased expression for ``expr AS name``, the
    projection itself otherwise. When several projections share a name, the last wins.
    Returns the result of ``predicate.transform``.
    """

    def _projection_entry(projection):
        if isinstance(projection, exp.Alias):
            return projection.alias, projection.this
        return projection.name, projection

    aliases = dict(_projection_entry(projection) for projection in source.selects)

    def _substitute(node):
        if not isinstance(node, exp.Column) or node.name not in aliases:
            return node
        return aliases[node.name].copy()

    return predicate.transform(_substitute)
def pushdown_predicates( expression: ~E, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> ~E:
def pushdown_predicates(expression: E, dialect: DialectType = None) -> E:
    """
    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

    Scopes are visited in the reverse of ``root.traverse()`` order. In each scope the
    WHERE condition is pushed into the sources it references, and every JOIN's ON
    condition is pushed into that join's own source only.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: dialect used when simplifying conditions. For Presto (and its
            subclasses) and Athena, WHERE predicates are not pushed down when an
            UNNEST source is joined.
    Returns:
        sqlglot.Expression: optimized expression (the same object, modified in place)
    """
    from sqlglot.dialects.athena import Athena
    from sqlglot.dialects.presto import Presto

    root = build_scope(expression)

    dialect = Dialect.get_or_raise(dialect)
    # presto, trino and athena don't support inner joins where the RHS is an UNNEST expression
    unnest_requires_cross_join = isinstance(dialect, (Athena, Presto))

    if not root:
        return expression

    scope_ref_count = root.ref_count()

    for scope in reversed(list(root.traverse())):
        select = scope.expression
        joins = select.args.get("joins") or []
        where = select.args.get("where")

        if where:
            candidates = scope.selected_sources
            positions = {join.alias_or_name: position for position, join in enumerate(joins)}
            blocked = False

            for alias, (node, source) in candidates.items():
                container = node.find_ancestor(exp.Join, exp.From)
                if not isinstance(container, exp.Join):
                    continue
                # a right join can only push down to itself and not the source FROM table
                if container.side == "RIGHT":
                    candidates = {alias: (node, source)}
                    break
                if unnest_requires_cross_join and isinstance(node, exp.Unnest):
                    blocked = True
                    break

            if not blocked:
                pushdown(where.this, candidates, scope_ref_count, dialect, positions)

        # joins should only pushdown into themselves, not to other joins,
        # so the sources are limited to the join's own entry
        for join in joins:
            alias = join.alias_or_name
            if alias in scope.selected_sources:
                pushdown(
                    join.args.get("on"),
                    {alias: scope.selected_sources[alias]},
                    scope_ref_count,
                    dialect,
                )

    return expression

Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
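Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • dialect (DialectType, optional): dialect used when simplifying predicates; for Presto and Athena, WHERE predicates are not pushed down when an UNNEST source is joined
Returns:
sqlglot.Expression: optimized expression (the same object, modified in place)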

def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
 84def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
 85    if not condition:
 86        return
 87
 88    condition = condition.replace(simplify(condition, dialect=dialect))
 89    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
 90
 91    predicates = list(
 92        condition.flatten()
 93        if isinstance(condition, exp.And if cnf_like else exp.Or)
 94        else [condition]
 95    )
 96
 97    if cnf_like:
 98        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
 99    else:
100        pushdown_dnf(predicates, sources, scope_ref_count)
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    If the predicates are in CNF-like form, each block can be pushed down on its own
    and replaced with TRUE in the parent.

    A block that targets a JOIN is attached to that JOIN's ON clause only when every
    other table it references is joined earlier according to ``join_index`` (tables
    absent from it count as earlier); ``join_index`` must contain the JOIN's alias.
    A block that targets a SELECT is rewritten through the SELECT's aliases and added
    to its HAVING clause if it contains an aggregate, otherwise to its WHERE clause.
    """
    positions = join_index or {}

    for predicate in predicates:
        targets = nodes_for_predicate(predicate, sources, scope_ref_count)

        for target in targets.values():
            if isinstance(target, exp.Select):
                predicate.replace(exp.true())
                pushed = replace_aliases(target, predicate)
                if find_in_scope(pushed, exp.AggFunc):
                    target.having(pushed, copy=False)
                else:
                    target.where(pushed, copy=False)
                continue

            if not isinstance(target, exp.Join):
                continue

            alias = target.alias_or_name
            referenced = exp.column_table_names(predicate, alias)

            # Don't push the predicate if it references tables that appear in later joins
            own_position = positions[alias]
            if any(positions.get(table, -1) >= own_position for table in referenced):
                continue

            predicate.replace(exp.true())
            target.on(predicate, copy=False)
            break

If the predicates are in CNF-like form, each block can be pushed down on its own and replaced with TRUE in the parent.

def pushdown_dnf(predicates, sources, scope_ref_count):
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions on tables that
    are referenced in every block. Additionally, we can't remove predicates from their
    original form.

    For each such table, the disjunction of all blocks is pushed into the table's
    target node, but only when *every* block resolves to a pushdown target for that
    table. Pushing a subset of the blocks would be a stronger filter than the original
    condition and would drop rows matched only by the other blocks. Each target
    receives its condition exactly once, and the pushed condition is always a copy, so
    the original predicates stay in place.
    """
    if not predicates:
        return

    # find all the tables that can be pushed down to:
    # these are tables that are referenced in all blocks of a DNF, e.g. in
    # (a.x AND b.x) OR (a.y AND c.y) only table a can be pushed down to
    table_sets = [exp.column_table_names(predicate) for predicate in predicates]
    pushdown_tables = set(table_sets[0]).intersection(*table_sets[1:])

    if not pushdown_tables:
        return

    # Target resolution does not depend on the table being processed, and pushing
    # copies leaves the original predicates where they are, so resolve once per block.
    resolved = [
        nodes_for_predicate(predicate, sources, scope_ref_count) for predicate in predicates
    ]

    # pushdown the combined predicates to their respective nodes
    for table in sorted(pushdown_tables):
        target = None
        condition = None

        for predicate, nodes in zip(predicates, resolved):
            node = nodes.get(table)
            if node is None:
                # this block can't be pushed into the table, so neither can the OR
                target = None
                break
            target = node
            condition = predicate.copy() if condition is None else exp.or_(condition, predicate)

        if target is None:
            continue

        # NOTE(review): unlike pushdown_cnf, there is no check that a JOIN target does not
        # reference tables joined after it; confirm this is safe for multi-join DNF filters.
        if isinstance(target, exp.Join):
            target.on(condition, copy=False)
        elif isinstance(target, exp.Select):
            inner_predicate = replace_aliases(target, condition)
            if find_in_scope(inner_predicate, exp.AggFunc):
                target.having(inner_predicate, copy=False)
            else:
                target.where(inner_predicate, copy=False)

If the predicates are in DNF form, we can only push down conditions on tables that are referenced in every block. Additionally, we can't remove predicates from their original form.

def nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by ``predicate`` to the node it may be pushed into.

    For predicates inside a WHERE, a source is resolved to its enclosing JOIN or FROM;
    a non-table FROM source (derived table/CTE) is resolved to its scope's expression.
    JOIN targets are kept unless the join has a side other than RIGHT. SELECT targets
    are kept only for single-table predicates, when the SELECT has no GROUP BY, its
    source is referenced fewer than twice, and no projection contains a window.

    Returns an empty dict (nothing may be pushed) if any referenced source is in a
    sided non-RIGHT join, or is a non-table FROM source whose parent scope's WITH
    clause is recursive.
    """
    referenced = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
    targets = {}

    for name in sorted(referenced):
        target, source = sources.get(name) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # to the root join or from statement
        if target and in_where:
            target = target.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(target, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with_")
            if with_ and with_.recursive:
                return {}
            target = source.expression

        if isinstance(target, exp.Join):
            if target.side and target.side != "RIGHT":
                return {}
            targets[name] = target
            continue

        if not isinstance(target, exp.Select) or len(referenced) != 1:
            continue
        if target.args.get("group"):
            continue
        # can't push into a select that is referenced in multiple places
        if scope_ref_count[id(source)] >= 2:
            continue
        # can't push past window expressions
        if any(projection.find(exp.Window) for projection in target.selects):
            continue

        targets[name] = target

    return targets
def replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Rewrite ``predicate`` in terms of the projections of ``source``.

    Each column whose name matches an output name of ``source`` is replaced by a copy
    of the expression producing it: the aliased expression for ``expr AS name``, the
    projection itself otherwise. When several projections share a name, the last wins.
    Returns the result of ``predicate.transform``.
    """

    def _projection_entry(projection):
        if isinstance(projection, exp.Alias):
            return projection.alias, projection.this
        return projection.name, projection

    aliases = dict(_projection_entry(projection) for projection in source.selects)

    def _substitute(node):
        if not isinstance(node, exp.Column) or node.name not in aliases:
            return node
        return aliases[node.name].copy()

    return predicate.transform(_substitute)
232    return predicate.transform(_replace_alias)