Edit on GitHub

sqlglot.optimizer.pushdown_predicates

  1from sqlglot import exp
  2from sqlglot.optimizer.normalize import normalized
  3from sqlglot.optimizer.scope import build_scope, find_in_scope
  4from sqlglot.optimizer.simplify import simplify
  5from sqlglot import Dialect
  6
  7
  8def pushdown_predicates(expression, dialect=None):
  9    """
 10    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
 15        >>> expression = sqlglot.parse_one(sql)
 16        >>> pushdown_predicates(expression).sql()
 17        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
 18
 19    Args:
 20        expression (sqlglot.Expression): expression to optimize
 21    Returns:
 22        sqlglot.Expression: optimized expression
 23    """
 24    from sqlglot.dialects.athena import Athena
 25    from sqlglot.dialects.presto import Presto
 26
 27    root = build_scope(expression)
 28
 29    dialect = Dialect.get_or_raise(dialect)
 30    unnest_requires_cross_join = isinstance(dialect, (Athena, Presto))
 31
 32    if root:
 33        scope_ref_count = root.ref_count()
 34
 35        for scope in reversed(list(root.traverse())):
 36            select = scope.expression
 37            where = select.args.get("where")
 38            if where:
 39                selected_sources = scope.selected_sources
 40                join_index = {
 41                    join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
 42                }
 43
 44                # a right join can only push down to itself and not the source FROM table
 45                # presto, trino and athena don't support inner joins where the RHS is an UNNEST expression
 46                pushdown_allowed = True
 47                for k, (node, source) in selected_sources.items():
 48                    parent = node.find_ancestor(exp.Join, exp.From)
 49                    if isinstance(parent, exp.Join):
 50                        if parent.side == "RIGHT":
 51                            selected_sources = {k: (node, source)}
 52                            break
 53                        if isinstance(node, exp.Unnest) and unnest_requires_cross_join:
 54                            pushdown_allowed = False
 55                            break
 56
 57                if pushdown_allowed:
 58                    pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)
 59
 60            # joins should only pushdown into itself, not to other joins
 61            # so we limit the selected sources to only itself
 62            for join in select.args.get("joins") or []:
 63                name = join.alias_or_name
 64                if name in scope.selected_sources:
 65                    pushdown(
 66                        join.args.get("on"),
 67                        {name: scope.selected_sources[name]},
 68                        scope_ref_count,
 69                        dialect,
 70                    )
 71
 72    return expression
 73
 74
 75def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
 76    if not condition:
 77        return
 78
 79    condition = condition.replace(simplify(condition, dialect=dialect))
 80    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
 81
 82    predicates = list(
 83        condition.flatten()
 84        if isinstance(condition, exp.And if cnf_like else exp.Or)
 85        else [condition]
 86    )
 87
 88    if cnf_like:
 89        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
 90    else:
 91        pushdown_dnf(predicates, sources, scope_ref_count)
 92
 93
 94def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
 95    """
 96    If the predicates are in CNF like form, we can simply replace each block in the parent.
 97    """
 98    join_index = join_index or {}
 99    for predicate in predicates:
100        for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
101            if isinstance(node, exp.Join):
102                name = node.alias_or_name
103                predicate_tables = exp.column_table_names(predicate, name)
104
105                # Don't push the predicate if it references tables that appear in later joins
106                this_index = join_index[name]
107                if all(join_index.get(table, -1) < this_index for table in predicate_tables):
108                    predicate.replace(exp.true())
109                    node.on(predicate, copy=False)
110                    break
111            if isinstance(node, exp.Select):
112                predicate.replace(exp.true())
113                inner_predicate = replace_aliases(node, predicate)
114                if find_in_scope(inner_predicate, exp.AggFunc):
115                    node.having(inner_predicate, copy=False)
116                else:
117                    node.where(inner_predicate, copy=False)
118
119
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    For every table referenced by all blocks, the OR of the blocks is attached to that
    table's node (a join's ON clause, or a select's WHERE / HAVING clause). Nothing is
    pushed for a table unless every block resolves to a node for it.

    Args:
        predicates: the blocks (disjuncts) of the DNF condition.
        sources: mapping of source name to a `(node, source)` pair.
        scope_ref_count: reference counts, forwarded to `nodes_for_predicate`.
    """
    # find all the tables that can be pushdown too
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be push down
    table_sets = [set(exp.column_table_names(predicate)) for predicate in predicates]
    pushdown_tables = set.intersection(*table_sets) if table_sets else set()

    for table in sorted(pushdown_tables):
        target = None
        combined = None

        for predicate in predicates:
            node = nodes_for_predicate(predicate, sources, scope_ref_count).get(table)

            if node is None:
                # one block cannot be pushed to this table: pushing an OR of only the
                # other blocks would filter out rows that this block still accepts
                combined = None
                break

            target = node
            combined = predicate if combined is None else exp.or_(combined, predicate)

        if combined is None:
            continue

        # push exactly once per table, to the node resolved for that table
        if isinstance(target, exp.Join):
            target.on(combined, copy=False)
        elif isinstance(target, exp.Select):
            inner_predicate = replace_aliases(target, combined)
            # predicates containing an aggregate function go to HAVING, the rest to WHERE
            if find_in_scope(inner_predicate, exp.AggFunc):
                target.having(inner_predicate, copy=False)
            else:
                target.where(inner_predicate, copy=False)
167
168
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by `predicate` to the node the predicate may be pushed into.

    Args:
        predicate: the predicate being considered for pushdown.
        sources: mapping of source name to a `(node, source)` pair.
        scope_ref_count: mapping keyed by `id(source)` giving how often a source is referenced.
    Returns:
        dict: table name to `exp.Join` or `exp.Select` node. Empty when a recursive
        WITH or a sided join other than RIGHT is encountered.
    """
    tables = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
    single_table = len(tables) == 1
    targets = {}

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if in_where and node:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        # selects are only eligible when the predicate references a single table
        if not (single_table and isinstance(node, exp.Select)):
            continue

        # We can't push down window expressions
        has_window_expression = any(select.find(exp.Window) for select in node.selects)

        # we can't push down predicates to select statements if they are referenced in
        # multiple places.
        if node.args.get("group") or scope_ref_count[id(source)] >= 2:
            continue
        if has_window_expression:
            continue

        targets[table] = node

    return targets
207
208
def replace_aliases(source, predicate):
    """
    Return `predicate` with columns swapped for the projections of `source` they name.

    Args:
        source: a select whose `selects` define the output names.
        predicate: the predicate written against those output names.
    Returns:
        The result of `predicate.transform`, where each column whose name matches a
        projection is replaced by a copy of that projection's expression.
    """
    projections = {}

    for projection in source.selects:
        if isinstance(projection, exp.Alias):
            # for `expr AS name`, map the alias to the aliased expression
            projections[projection.alias] = projection.this
        else:
            projections[projection.name] = projection

    def _substitute(node):
        if isinstance(node, exp.Column):
            name = node.name
            if name in projections:
                return projections[name].copy()
        return node

    return predicate.transform(_substitute)
def pushdown_predicates(expression, dialect=None):
 9def pushdown_predicates(expression, dialect=None):
10    """
11    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
12
13    Example:
14        >>> import sqlglot
15        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
16        >>> expression = sqlglot.parse_one(sql)
17        >>> pushdown_predicates(expression).sql()
18        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
19
20    Args:
21        expression (sqlglot.Expression): expression to optimize
22    Returns:
23        sqlglot.Expression: optimized expression
24    """
25    from sqlglot.dialects.athena import Athena
26    from sqlglot.dialects.presto import Presto
27
28    root = build_scope(expression)
29
30    dialect = Dialect.get_or_raise(dialect)
31    unnest_requires_cross_join = isinstance(dialect, (Athena, Presto))
32
33    if root:
34        scope_ref_count = root.ref_count()
35
36        for scope in reversed(list(root.traverse())):
37            select = scope.expression
38            where = select.args.get("where")
39            if where:
40                selected_sources = scope.selected_sources
41                join_index = {
42                    join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
43                }
44
45                # a right join can only push down to itself and not the source FROM table
46                # presto, trino and athena don't support inner joins where the RHS is an UNNEST expression
47                pushdown_allowed = True
48                for k, (node, source) in selected_sources.items():
49                    parent = node.find_ancestor(exp.Join, exp.From)
50                    if isinstance(parent, exp.Join):
51                        if parent.side == "RIGHT":
52                            selected_sources = {k: (node, source)}
53                            break
54                        if isinstance(node, exp.Unnest) and unnest_requires_cross_join:
55                            pushdown_allowed = False
56                            break
57
58                if pushdown_allowed:
59                    pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)
60
61            # joins should only pushdown into itself, not to other joins
62            # so we limit the selected sources to only itself
63            for join in select.args.get("joins") or []:
64                name = join.alias_or_name
65                if name in scope.selected_sources:
66                    pushdown(
67                        join.args.get("on"),
68                        {name: scope.selected_sources[name]},
69                        scope_ref_count,
70                        dialect,
71                    )
72
73    return expression

Rewrite a sqlglot AST to push predicates down into FROMs and JOINs.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
Returns:

sqlglot.Expression: optimized expression

def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify a condition and dispatch its predicates to the CNF or DNF pushdown.

    Args:
        condition: the boolean expression to push down; a falsy value is a no-op.
        sources: mapping of source name to a `(node, source)` pair.
        scope_ref_count: reference counts, forwarded to the pushdown helpers.
        dialect: dialect handed to `simplify`.
        join_index: optional mapping of join name to position, only used by the CNF path.
    """
    if not condition:
        return

    # swap the condition in its parent for the simplified version and keep working on that
    condition = condition.replace(simplify(condition, dialect=dialect))

    # CNF handling is used unless the condition is only normalized as DNF
    if normalized(condition):
        cnf_like = True
    else:
        cnf_like = not normalized(condition, dnf=True)

    # split on AND for CNF and on OR for DNF; anything else is a single predicate
    connector = exp.And if cnf_like else exp.Or
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    if cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
    else:
        pushdown_dnf(predicates, sources, scope_ref_count)
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
 95def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
 96    """
 97    If the predicates are in CNF like form, we can simply replace each block in the parent.
 98    """
 99    join_index = join_index or {}
100    for predicate in predicates:
101        for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
102            if isinstance(node, exp.Join):
103                name = node.alias_or_name
104                predicate_tables = exp.column_table_names(predicate, name)
105
106                # Don't push the predicate if it references tables that appear in later joins
107                this_index = join_index[name]
108                if all(join_index.get(table, -1) < this_index for table in predicate_tables):
109                    predicate.replace(exp.true())
110                    node.on(predicate, copy=False)
111                    break
112            if isinstance(node, exp.Select):
113                predicate.replace(exp.true())
114                inner_predicate = replace_aliases(node, predicate)
115                if find_in_scope(inner_predicate, exp.AggFunc):
116                    node.having(inner_predicate, copy=False)
117                else:
118                    node.where(inner_predicate, copy=False)

If the predicates are in CNF like form, we can simply replace each block in the parent.

def pushdown_dnf(predicates, sources, scope_ref_count):
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    For every table referenced by all blocks, the OR of the blocks is attached to that
    table's node (a join's ON clause, or a select's WHERE / HAVING clause). Nothing is
    pushed for a table unless every block resolves to a node for it.

    Args:
        predicates: the blocks (disjuncts) of the DNF condition.
        sources: mapping of source name to a `(node, source)` pair.
        scope_ref_count: reference counts, forwarded to `nodes_for_predicate`.
    """
    # find all the tables that can be pushdown too
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be push down
    table_sets = [set(exp.column_table_names(predicate)) for predicate in predicates]
    pushdown_tables = set.intersection(*table_sets) if table_sets else set()

    for table in sorted(pushdown_tables):
        target = None
        combined = None

        for predicate in predicates:
            node = nodes_for_predicate(predicate, sources, scope_ref_count).get(table)

            if node is None:
                # one block cannot be pushed to this table: pushing an OR of only the
                # other blocks would filter out rows that this block still accepts
                combined = None
                break

            target = node
            combined = predicate if combined is None else exp.or_(combined, predicate)

        if combined is None:
            continue

        # push exactly once per table, to the node resolved for that table
        if isinstance(target, exp.Join):
            target.on(combined, copy=False)
        elif isinstance(target, exp.Select):
            inner_predicate = replace_aliases(target, combined)
            # predicates containing an aggregate function go to HAVING, the rest to WHERE
            if find_in_scope(inner_predicate, exp.AggFunc):
                target.having(inner_predicate, copy=False)
            else:
                target.where(inner_predicate, copy=False)

If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form.

def nodes_for_predicate(predicate, sources, scope_ref_count):
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by `predicate` to the node the predicate may be pushed into.

    Args:
        predicate: the predicate being considered for pushdown.
        sources: mapping of source name to a `(node, source)` pair.
        scope_ref_count: mapping keyed by `id(source)` giving how often a source is referenced.
    Returns:
        dict: table name to `exp.Join` or `exp.Select` node. Empty when a recursive
        WITH or a sided join other than RIGHT is encountered.
    """
    tables = exp.column_table_names(predicate)
    in_where = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
    single_table = len(tables) == 1
    targets = {}

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if in_where and node:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side and node.side != "RIGHT":
                return {}
            targets[table] = node
            continue

        # selects are only eligible when the predicate references a single table
        if not (single_table and isinstance(node, exp.Select)):
            continue

        # We can't push down window expressions
        has_window_expression = any(select.find(exp.Window) for select in node.selects)

        # we can't push down predicates to select statements if they are referenced in
        # multiple places.
        if node.args.get("group") or scope_ref_count[id(source)] >= 2:
            continue
        if has_window_expression:
            continue

        targets[table] = node

    return targets
def replace_aliases(source, predicate):
def replace_aliases(source, predicate):
    """
    Return `predicate` with columns swapped for the projections of `source` they name.

    Args:
        source: a select whose `selects` define the output names.
        predicate: the predicate written against those output names.
    Returns:
        The result of `predicate.transform`, where each column whose name matches a
        projection is replaced by a copy of that projection's expression.
    """
    projections = {}

    for projection in source.selects:
        if isinstance(projection, exp.Alias):
            # for `expr AS name`, map the alias to the aliased expression
            projections[projection.alias] = projection.this
        else:
            projections[projection.name] = projection

    def _substitute(node):
        if isinstance(node, exp.Column):
            name = node.name
            if name in projections:
                return projections[name].copy()
        return node

    return predicate.transform(_substitute)