Edit on GitHub

sqlglot.planner

  1from __future__ import annotations
  2
  3import math
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.helper import name_sequence
  8from sqlglot.optimizer.eliminate_joins import join_condition
  9from collections.abc import Iterator, Sequence, Iterable
 10
 11
 12class Plan:
 13    def __init__(self, expression: exp.Expr) -> None:
 14        self.expression = expression.copy()
 15        self.root = Step.from_expression(self.expression)
 16        self._dag: dict[Step, set[Step]] = {}
 17
 18    @property
 19    def dag(self) -> dict[Step, set[Step]]:
 20        if not self._dag:
 21            dag: dict[Step, set[Step]] = {}
 22            nodes = {self.root}
 23
 24            while nodes:
 25                node = nodes.pop()
 26                dag[node] = set()
 27
 28                for dep in node.dependencies:
 29                    dag[node].add(dep)
 30                    nodes.add(dep)
 31
 32            self._dag = dag
 33
 34        return self._dag
 35
 36    @property
 37    def leaves(self) -> Iterator[Step]:
 38        return (node for node, deps in self.dag.items() if not deps)
 39
 40    def __repr__(self) -> str:
 41        return f"Plan\n----\n{repr(self.root)}"
 42
 43
 44class Step:
 45    @classmethod
 46    def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = None) -> Step:
 47        """
 48        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
 49        Note: the expression's tables and subqueries must be aliased for this method to work. For
 50        example, given the following expression:
 51
 52        SELECT
 53          x.a,
 54          SUM(x.b)
 55        FROM x AS x
 56        JOIN y AS y
 57          ON x.a = y.a
 58        GROUP BY x.a
 59
 60        the following DAG is produced (the expression IDs might differ per execution):
 61
 62        - Aggregate: x (4347984624)
 63            Context:
 64              Aggregations:
 65                - SUM(x.b)
 66              Group:
 67                - x.a
 68            Projections:
 69              - x.a
 70              - "x".""
 71            Dependencies:
 72            - Join: x (4347985296)
 73              Context:
 74                y:
 75                On: x.a = y.a
 76              Projections:
 77              Dependencies:
 78              - Scan: x (4347983136)
 79                Context:
 80                  Source: x AS x
 81                Projections:
 82              - Scan: y (4343416624)
 83                Context:
 84                  Source: y AS y
 85                Projections:
 86
 87        Args:
 88            expression: the expression to build the DAG from.
 89            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
 90
 91        Returns:
 92            A Step DAG corresponding to `expression`.
 93        """
 94        ctes = ctes or {}
 95        expression = expression.unnest()
 96        with_ = expression.args.get("with_")
 97
 98        # CTEs break the mold of scope and introduce themselves to all in the context.
 99        if with_:
100            ctes = ctes.copy()
101            for cte in with_.expressions:
102                step = Step.from_expression(cte.this, ctes)
103                step.name = cte.alias
104                ctes[step.name] = step  # type: ignore
105
106        from_ = expression.args.get("from_")
107
108        if isinstance(expression, exp.Select) and from_:
109            step = Scan.from_expression(from_.this, ctes)
110        elif isinstance(expression, exp.SetOperation):
111            step = SetOperation.from_expression(expression, ctes)
112        else:
113            step = Scan()
114
115        joins = expression.args.get("joins")
116
117        if joins:
118            join = Join.from_joins(joins, ctes)
119            join.name = step.name
120            join.source_name = step.name
121            join.add_dependency(step)
122            step = join
123
124        projections: list[
125            exp.Expr
126        ] = []  # final selects in this chain of steps representing a select
127        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
128        aggregations = {}
129        next_operand_name = name_sequence("_a_")
130
131        def extract_agg_operands(expression):
132            agg_funcs = tuple(expression.find_all(exp.AggFunc))
133            if agg_funcs:
134                aggregations[expression] = None
135
136            for agg in agg_funcs:
137                for operand in agg.unnest_operands():
138                    if isinstance(operand, exp.Column):
139                        continue
140                    if operand not in operands:
141                        operands[operand] = next_operand_name()
142
143                    operand.replace(exp.column(operands[operand], quoted=True))
144
145            return bool(agg_funcs)
146
147        def set_ops_and_aggs(step):
148            step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
149            step.aggregations = list(aggregations)
150
151        for e in expression.expressions:
152            if e.find(exp.AggFunc):
153                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
154                extract_agg_operands(e)
155            else:
156                projections.append(e)
157
158        where = expression.args.get("where")
159
160        if where:
161            step.condition = where.this
162
163        group = expression.args.get("group")
164
165        if group or aggregations:
166            aggregate = Aggregate()
167            aggregate.source = step.name
168            aggregate.name = step.name
169
170            having = expression.args.get("having")
171
172            if having:
173                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
174                    aggregate.condition = exp.column("_h", step.name, quoted=True)
175                else:
176                    aggregate.condition = having.this
177
178            set_ops_and_aggs(aggregate)
179
180            # give aggregates names and replace projections with references to them
181            aggregate.group = {
182                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
183            }
184
185            intermediate: dict[str | exp.Expr, str] = {}
186            for k, v in aggregate.group.items():
187                intermediate[v] = k
188                if isinstance(v, exp.Column):
189                    intermediate[v.name] = k
190
191            for projection in projections:
192                for node in projection.walk():
193                    name = intermediate.get(node)
194                    if name:
195                        node.replace(exp.column(name, step.name))
196
197            if aggregate.condition:
198                for node in aggregate.condition.walk():
199                    name = intermediate.get(node) or intermediate.get(node.name)
200                    if name:
201                        node.replace(exp.column(name, step.name))
202
203            aggregate.add_dependency(step)
204            step = aggregate
205        else:
206            aggregate = None
207
208        order = expression.args.get("order")
209
210        if order:
211            if aggregate and isinstance(step, Aggregate):
212                for i, ordered in enumerate(order.expressions):
213                    if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
214                        ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
215
216                set_ops_and_aggs(aggregate)
217
218            sort = Sort()
219            sort.name = step.name
220            sort.key = order.expressions
221            sort.add_dependency(step)
222            step = sort
223
224        step.projections = projections
225
226        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
227            distinct = Aggregate()
228            distinct.source = step.name
229            distinct.name = step.name
230            distinct.group = {
231                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
232                for e in projections or expression.expressions
233            }
234            distinct.add_dependency(step)
235            step = distinct
236
237        limit = expression.args.get("limit")
238
239        if limit:
240            step.limit = int(limit.text("expression"))
241
242        return step
243
244    def __init__(self) -> None:
245        self.name: str | None = None
246        self.dependencies: set[Step] = set()
247        self.dependents: set[Step] = set()
248        self.projections: Sequence[exp.Expr] = []
249        self.limit: float = math.inf
250        self.condition: exp.Expr | None = None
251
252    def add_dependency(self, dependency: Step) -> None:
253        self.dependencies.add(dependency)
254        dependency.dependents.add(self)
255
256    def __repr__(self) -> str:
257        return self.to_s()
258
259    def to_s(self, level: int = 0) -> str:
260        indent = "  " * level
261        nested = f"{indent}    "
262
263        context = self._to_s(f"{nested}  ")
264
265        if context:
266            context = [f"{nested}Context:"] + context
267
268        lines = [
269            f"{indent}- {self.id}",
270            *context,
271            f"{nested}Projections:",
272        ]
273
274        for expression in self.projections:
275            lines.append(f"{nested}  - {expression.sql()}")
276
277        if self.condition:
278            lines.append(f"{nested}Condition: {self.condition.sql()}")
279
280        if self.limit is not math.inf:
281            lines.append(f"{nested}Limit: {self.limit}")
282
283        if self.dependencies:
284            lines.append(f"{nested}Dependencies:")
285            for dependency in self.dependencies:
286                lines.append("  " + dependency.to_s(level + 1))
287
288        return "\n".join(lines)
289
290    @property
291    def type_name(self) -> str:
292        return self.__class__.__name__
293
294    @property
295    def id(self) -> str:
296        name = self.name
297        name = f" {name}" if name else ""
298        return f"{self.type_name}:{name} ({id(self)})"
299
300    def _to_s(self, _indent: str) -> list[str]:
301        return []
302
303
class Scan(Step):
    """Step that reads from a source expression (`source`), or from nothing when it is None."""

    @classmethod
    def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = None) -> Step:
        """
        Builds the step for a FROM/JOIN item.

        A subquery is planned via `Step.from_expression` and the resulting step is
        renamed to the subquery's alias. Anything else becomes a `Scan` over
        `expression`, depending on the matching CTE step when its name is in `ctes`.
        """
        label = expression.alias_or_name

        if isinstance(expression, exp.Subquery):
            planned = Step.from_expression(expression.this, ctes)
            planned.name = label
            return planned

        scan = Scan()
        scan.name = label
        scan.source = expression

        if ctes:
            source_name = expression.name
            if source_name in ctes:
                scan.add_dependency(ctes[source_name])

        return scan

    def __init__(self) -> None:
        super().__init__()
        self.source: exp.Expr | None = None

    def _to_s(self, indent: str) -> list[str]:
        # A scan without a source is rendered with the '-static-' placeholder.
        rendered = self.source.sql() if self.source else "-static-"  # type: ignore
        return [f"{indent}Source: {rendered}"]
330
331
class Join(Step):
    """Step combining a source with one or more joined inputs, keyed by join alias/name."""

    @classmethod
    def from_joins(cls, joins: Iterable[exp.Join], ctes: dict[str, Step] | None = None) -> Join:
        """
        Builds a single Join step from a sequence of join expressions.

        For each join, the keys and remaining condition returned by `join_condition`
        are stored under the join's alias/name, and a scan of the joined item is
        added as a dependency.
        """
        result = Join()

        for joined in joins:
            source_key, join_key, condition = join_condition(joined)
            spec = dict(
                side=joined.side,  # type: ignore
                join_key=join_key,
                source_key=source_key,
                condition=condition,
            )
            result.joins[joined.alias_or_name] = spec
            result.add_dependency(Scan.from_expression(joined.this, ctes))

        return result

    def __init__(self) -> None:
        super().__init__()
        self.source_name: str | None = None
        self.joins: dict[str, dict[str, list[str] | exp.Expr]] = {}

    def _to_s(self, indent: str) -> list[str]:
        out = [f"{indent}Source: {self.source_name or self.name}"]

        for alias_name, spec in self.joins.items():
            # A falsy side is displayed as INNER.
            side = spec["side"] or "INNER"
            out.append(f"{indent}{alias_name}: {side}")

            keys = t.cast(list, spec.get("join_key") or [])
            rendered_keys = ", ".join(map(str, keys))
            if rendered_keys:
                out.append(f"{indent}Key: {rendered_keys}")

            if spec.get("condition"):
                out.append(f"{indent}On: {spec['condition'].sql()}")  # type: ignore

        return out
365
366
class Aggregate(Step):
    """Step holding aggregations, their operands, group-by expressions and a source name."""

    def __init__(self) -> None:
        super().__init__()
        self.aggregations: list[exp.Expr] = []
        self.operands: tuple[exp.Expr, ...] = ()
        self.group: dict[str, exp.Expr] = {}
        self.source: str | None = None

    def _to_s(self, indent: str) -> list[str]:
        bullet = f"{indent}  - "

        # The "Aggregations:" header is always present; the other sections are optional.
        out = [f"{indent}Aggregations:"]
        out.extend(f"{bullet}{agg.sql()}" for agg in self.aggregations)

        if self.group:
            out.append(f"{indent}Group:")
            out.extend(f"{bullet}{grouped.sql()}" for grouped in self.group.values())

        if self.condition:
            out.extend((f"{indent}Having:", f"{bullet}{self.condition.sql()}"))

        if self.operands:
            out.append(f"{indent}Operands:")
            out.extend(f"{bullet}{operand.sql()}" for operand in self.operands)

        return out
394
395
class Sort(Step):
    """Step that orders its input by the expressions stored in `key`."""

    def __init__(self) -> None:
        super().__init__()
        # Sort expressions; None until assigned by whoever builds the step.
        self.key: Sequence[exp.Expr] | None = None

    def _to_s(self, indent: str) -> list[str]:
        """Returns the "Key:" header followed by one bullet per sort expression."""
        lines = [f"{indent}Key:"]

        # `key` stays None until it is assigned; treat that as "no keys" so that
        # repr() of a freshly constructed Sort does not raise TypeError.
        for expression in self.key or ():
            lines.append(f"{indent}  - {expression.sql()}")

        return lines
408
409
class SetOperation(Step):
    """Step for a set operation over two named inputs (`left` and `right`)."""

    def __init__(
        self,
        op: type[exp.Expr],
        left: str | None,
        right: str | None,
        distinct: bool = False,
    ) -> None:
        super().__init__()
        self.op, self.left, self.right = op, left, right
        self.distinct = distinct

    @classmethod
    def from_expression(
        cls, expression: exp.Expr, ctes: dict[str, Step] | None = None
    ) -> SetOperation:
        """
        Plans both sides of a set operation and returns a step depending on them.

        Sides whose step has no name get the fallback names "left" / "right". A
        limit on the set operation is copied onto the returned step.
        """
        assert isinstance(expression, exp.SetOperation)

        lhs = Step.from_expression(expression.left, ctes)
        # SELECT 1 UNION SELECT 2  <-- these subqueries don't have names
        if not lhs.name:
            lhs.name = "left"

        rhs = Step.from_expression(expression.right, ctes)
        if not rhs.name:
            rhs.name = "right"

        is_distinct = bool(expression.args.get("distinct"))
        result = cls(
            op=expression.__class__,
            left=lhs.name,
            right=rhs.name,
            distinct=is_distinct,
        )

        for side in (lhs, rhs):
            result.add_dependency(side)

        limit = expression.args.get("limit")
        if limit:
            result.limit = int(limit.text("expression"))

        return result

    def _to_s(self, indent: str) -> list[str]:
        # Only a truthy `distinct` contributes a context line.
        return [f"{indent}Distinct: {self.distinct}"] if self.distinct else []

    @property
    def type_name(self) -> str:
        # Reports the name of the wrapped expression class instead of "SetOperation".
        return self.op.__name__
class Plan:
13class Plan:
14    def __init__(self, expression: exp.Expr) -> None:
15        self.expression = expression.copy()
16        self.root = Step.from_expression(self.expression)
17        self._dag: dict[Step, set[Step]] = {}
18
19    @property
20    def dag(self) -> dict[Step, set[Step]]:
21        if not self._dag:
22            dag: dict[Step, set[Step]] = {}
23            nodes = {self.root}
24
25            while nodes:
26                node = nodes.pop()
27                dag[node] = set()
28
29                for dep in node.dependencies:
30                    dag[node].add(dep)
31                    nodes.add(dep)
32
33            self._dag = dag
34
35        return self._dag
36
37    @property
38    def leaves(self) -> Iterator[Step]:
39        return (node for node, deps in self.dag.items() if not deps)
40
41    def __repr__(self) -> str:
42        return f"Plan\n----\n{repr(self.root)}"
Plan(expression: sqlglot.expressions.core.Expr)
14    def __init__(self, expression: exp.Expr) -> None:
15        self.expression = expression.copy()
16        self.root = Step.from_expression(self.expression)
17        self._dag: dict[Step, set[Step]] = {}
expression
root
dag: dict[Step, set[Step]]
19    @property
20    def dag(self) -> dict[Step, set[Step]]:
21        if not self._dag:
22            dag: dict[Step, set[Step]] = {}
23            nodes = {self.root}
24
25            while nodes:
26                node = nodes.pop()
27                dag[node] = set()
28
29                for dep in node.dependencies:
30                    dag[node].add(dep)
31                    nodes.add(dep)
32
33            self._dag = dag
34
35        return self._dag
leaves: Iterator[Step]
37    @property
38    def leaves(self) -> Iterator[Step]:
39        return (node for node, deps in self.dag.items() if not deps)
class Step:
 45class Step:
 46    @classmethod
 47    def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = None) -> Step:
 48        """
 49        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
 50        Note: the expression's tables and subqueries must be aliased for this method to work. For
 51        example, given the following expression:
 52
 53        SELECT
 54          x.a,
 55          SUM(x.b)
 56        FROM x AS x
 57        JOIN y AS y
 58          ON x.a = y.a
 59        GROUP BY x.a
 60
 61        the following DAG is produced (the expression IDs might differ per execution):
 62
 63        - Aggregate: x (4347984624)
 64            Context:
 65              Aggregations:
 66                - SUM(x.b)
 67              Group:
 68                - x.a
 69            Projections:
 70              - x.a
 71              - "x".""
 72            Dependencies:
 73            - Join: x (4347985296)
 74              Context:
 75                y:
 76                On: x.a = y.a
 77              Projections:
 78              Dependencies:
 79              - Scan: x (4347983136)
 80                Context:
 81                  Source: x AS x
 82                Projections:
 83              - Scan: y (4343416624)
 84                Context:
 85                  Source: y AS y
 86                Projections:
 87
 88        Args:
 89            expression: the expression to build the DAG from.
 90            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
 91
 92        Returns:
 93            A Step DAG corresponding to `expression`.
 94        """
 95        ctes = ctes or {}
 96        expression = expression.unnest()
 97        with_ = expression.args.get("with_")
 98
 99        # CTEs break the mold of scope and introduce themselves to all in the context.
100        if with_:
101            ctes = ctes.copy()
102            for cte in with_.expressions:
103                step = Step.from_expression(cte.this, ctes)
104                step.name = cte.alias
105                ctes[step.name] = step  # type: ignore
106
107        from_ = expression.args.get("from_")
108
109        if isinstance(expression, exp.Select) and from_:
110            step = Scan.from_expression(from_.this, ctes)
111        elif isinstance(expression, exp.SetOperation):
112            step = SetOperation.from_expression(expression, ctes)
113        else:
114            step = Scan()
115
116        joins = expression.args.get("joins")
117
118        if joins:
119            join = Join.from_joins(joins, ctes)
120            join.name = step.name
121            join.source_name = step.name
122            join.add_dependency(step)
123            step = join
124
125        projections: list[
126            exp.Expr
127        ] = []  # final selects in this chain of steps representing a select
128        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
129        aggregations = {}
130        next_operand_name = name_sequence("_a_")
131
132        def extract_agg_operands(expression):
133            agg_funcs = tuple(expression.find_all(exp.AggFunc))
134            if agg_funcs:
135                aggregations[expression] = None
136
137            for agg in agg_funcs:
138                for operand in agg.unnest_operands():
139                    if isinstance(operand, exp.Column):
140                        continue
141                    if operand not in operands:
142                        operands[operand] = next_operand_name()
143
144                    operand.replace(exp.column(operands[operand], quoted=True))
145
146            return bool(agg_funcs)
147
148        def set_ops_and_aggs(step):
149            step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
150            step.aggregations = list(aggregations)
151
152        for e in expression.expressions:
153            if e.find(exp.AggFunc):
154                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
155                extract_agg_operands(e)
156            else:
157                projections.append(e)
158
159        where = expression.args.get("where")
160
161        if where:
162            step.condition = where.this
163
164        group = expression.args.get("group")
165
166        if group or aggregations:
167            aggregate = Aggregate()
168            aggregate.source = step.name
169            aggregate.name = step.name
170
171            having = expression.args.get("having")
172
173            if having:
174                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
175                    aggregate.condition = exp.column("_h", step.name, quoted=True)
176                else:
177                    aggregate.condition = having.this
178
179            set_ops_and_aggs(aggregate)
180
181            # give aggregates names and replace projections with references to them
182            aggregate.group = {
183                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
184            }
185
186            intermediate: dict[str | exp.Expr, str] = {}
187            for k, v in aggregate.group.items():
188                intermediate[v] = k
189                if isinstance(v, exp.Column):
190                    intermediate[v.name] = k
191
192            for projection in projections:
193                for node in projection.walk():
194                    name = intermediate.get(node)
195                    if name:
196                        node.replace(exp.column(name, step.name))
197
198            if aggregate.condition:
199                for node in aggregate.condition.walk():
200                    name = intermediate.get(node) or intermediate.get(node.name)
201                    if name:
202                        node.replace(exp.column(name, step.name))
203
204            aggregate.add_dependency(step)
205            step = aggregate
206        else:
207            aggregate = None
208
209        order = expression.args.get("order")
210
211        if order:
212            if aggregate and isinstance(step, Aggregate):
213                for i, ordered in enumerate(order.expressions):
214                    if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
215                        ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
216
217                set_ops_and_aggs(aggregate)
218
219            sort = Sort()
220            sort.name = step.name
221            sort.key = order.expressions
222            sort.add_dependency(step)
223            step = sort
224
225        step.projections = projections
226
227        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
228            distinct = Aggregate()
229            distinct.source = step.name
230            distinct.name = step.name
231            distinct.group = {
232                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
233                for e in projections or expression.expressions
234            }
235            distinct.add_dependency(step)
236            step = distinct
237
238        limit = expression.args.get("limit")
239
240        if limit:
241            step.limit = int(limit.text("expression"))
242
243        return step
244
245    def __init__(self) -> None:
246        self.name: str | None = None
247        self.dependencies: set[Step] = set()
248        self.dependents: set[Step] = set()
249        self.projections: Sequence[exp.Expr] = []
250        self.limit: float = math.inf
251        self.condition: exp.Expr | None = None
252
253    def add_dependency(self, dependency: Step) -> None:
254        self.dependencies.add(dependency)
255        dependency.dependents.add(self)
256
257    def __repr__(self) -> str:
258        return self.to_s()
259
260    def to_s(self, level: int = 0) -> str:
261        indent = "  " * level
262        nested = f"{indent}    "
263
264        context = self._to_s(f"{nested}  ")
265
266        if context:
267            context = [f"{nested}Context:"] + context
268
269        lines = [
270            f"{indent}- {self.id}",
271            *context,
272            f"{nested}Projections:",
273        ]
274
275        for expression in self.projections:
276            lines.append(f"{nested}  - {expression.sql()}")
277
278        if self.condition:
279            lines.append(f"{nested}Condition: {self.condition.sql()}")
280
281        if self.limit is not math.inf:
282            lines.append(f"{nested}Limit: {self.limit}")
283
284        if self.dependencies:
285            lines.append(f"{nested}Dependencies:")
286            for dependency in self.dependencies:
287                lines.append("  " + dependency.to_s(level + 1))
288
289        return "\n".join(lines)
290
291    @property
292    def type_name(self) -> str:
293        return self.__class__.__name__
294
295    @property
296    def id(self) -> str:
297        name = self.name
298        name = f" {name}" if name else ""
299        return f"{self.type_name}:{name} ({id(self)})"
300
301    def _to_s(self, _indent: str) -> list[str]:
302        return []
@classmethod
def from_expression( cls, expression: sqlglot.expressions.core.Expr, ctes: dict[str, Step] | None = None) -> Step:
 46    @classmethod
 47    def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = None) -> Step:
 48        """
 49        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
 50        Note: the expression's tables and subqueries must be aliased for this method to work. For
 51        example, given the following expression:
 52
 53        SELECT
 54          x.a,
 55          SUM(x.b)
 56        FROM x AS x
 57        JOIN y AS y
 58          ON x.a = y.a
 59        GROUP BY x.a
 60
 61        the following DAG is produced (the expression IDs might differ per execution):
 62
 63        - Aggregate: x (4347984624)
 64            Context:
 65              Aggregations:
 66                - SUM(x.b)
 67              Group:
 68                - x.a
 69            Projections:
 70              - x.a
 71              - "x".""
 72            Dependencies:
 73            - Join: x (4347985296)
 74              Context:
 75                y:
 76                On: x.a = y.a
 77              Projections:
 78              Dependencies:
 79              - Scan: x (4347983136)
 80                Context:
 81                  Source: x AS x
 82                Projections:
 83              - Scan: y (4343416624)
 84                Context:
 85                  Source: y AS y
 86                Projections:
 87
 88        Args:
 89            expression: the expression to build the DAG from.
 90            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
 91
 92        Returns:
 93            A Step DAG corresponding to `expression`.
 94        """
 95        ctes = ctes or {}
 96        expression = expression.unnest()
 97        with_ = expression.args.get("with_")
 98
 99        # CTEs break the mold of scope and introduce themselves to all in the context.
100        if with_:
101            ctes = ctes.copy()
102            for cte in with_.expressions:
103                step = Step.from_expression(cte.this, ctes)
104                step.name = cte.alias
105                ctes[step.name] = step  # type: ignore
106
107        from_ = expression.args.get("from_")
108
109        if isinstance(expression, exp.Select) and from_:
110            step = Scan.from_expression(from_.this, ctes)
111        elif isinstance(expression, exp.SetOperation):
112            step = SetOperation.from_expression(expression, ctes)
113        else:
114            step = Scan()
115
116        joins = expression.args.get("joins")
117
118        if joins:
119            join = Join.from_joins(joins, ctes)
120            join.name = step.name
121            join.source_name = step.name
122            join.add_dependency(step)
123            step = join
124
125        projections: list[
126            exp.Expr
127        ] = []  # final selects in this chain of steps representing a select
128        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
129        aggregations = {}
130        next_operand_name = name_sequence("_a_")
131
132        def extract_agg_operands(expression):
133            agg_funcs = tuple(expression.find_all(exp.AggFunc))
134            if agg_funcs:
135                aggregations[expression] = None
136
137            for agg in agg_funcs:
138                for operand in agg.unnest_operands():
139                    if isinstance(operand, exp.Column):
140                        continue
141                    if operand not in operands:
142                        operands[operand] = next_operand_name()
143
144                    operand.replace(exp.column(operands[operand], quoted=True))
145
146            return bool(agg_funcs)
147
148        def set_ops_and_aggs(step):
149            step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
150            step.aggregations = list(aggregations)
151
152        for e in expression.expressions:
153            if e.find(exp.AggFunc):
154                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
155                extract_agg_operands(e)
156            else:
157                projections.append(e)
158
159        where = expression.args.get("where")
160
161        if where:
162            step.condition = where.this
163
164        group = expression.args.get("group")
165
166        if group or aggregations:
167            aggregate = Aggregate()
168            aggregate.source = step.name
169            aggregate.name = step.name
170
171            having = expression.args.get("having")
172
173            if having:
174                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
175                    aggregate.condition = exp.column("_h", step.name, quoted=True)
176                else:
177                    aggregate.condition = having.this
178
179            set_ops_and_aggs(aggregate)
180
181            # give aggregates names and replace projections with references to them
182            aggregate.group = {
183                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
184            }
185
186            intermediate: dict[str | exp.Expr, str] = {}
187            for k, v in aggregate.group.items():
188                intermediate[v] = k
189                if isinstance(v, exp.Column):
190                    intermediate[v.name] = k
191
192            for projection in projections:
193                for node in projection.walk():
194                    name = intermediate.get(node)
195                    if name:
196                        node.replace(exp.column(name, step.name))
197
198            if aggregate.condition:
199                for node in aggregate.condition.walk():
200                    name = intermediate.get(node) or intermediate.get(node.name)
201                    if name:
202                        node.replace(exp.column(name, step.name))
203
204            aggregate.add_dependency(step)
205            step = aggregate
206        else:
207            aggregate = None
208
209        order = expression.args.get("order")
210
211        if order:
212            if aggregate and isinstance(step, Aggregate):
213                for i, ordered in enumerate(order.expressions):
214                    if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
215                        ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
216
217                set_ops_and_aggs(aggregate)
218
219            sort = Sort()
220            sort.name = step.name
221            sort.key = order.expressions
222            sort.add_dependency(step)
223            step = sort
224
225        step.projections = projections
226
227        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
228            distinct = Aggregate()
229            distinct.source = step.name
230            distinct.name = step.name
231            distinct.group = {
232                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
233                for e in projections or expression.expressions
234            }
235            distinct.add_dependency(step)
236            step = distinct
237
238        limit = expression.args.get("limit")
239
240        if limit:
241            step.limit = int(limit.text("expression"))
242
243        return step

Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. Note: the expression's tables and subqueries must be aliased for this method to work. For example, given the following expression:

SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a

the following DAG is produced (the expression IDs might differ per execution):

- Aggregate: x (4347984624)
    Context:
      Aggregations:
        - SUM(x.b)
      Group:
        - x.a
    Projections:
      - x.a
      - "x".""
    Dependencies:
    - Join: x (4347985296)
      Context:
        y:
        On: x.a = y.a
      Projections:
      Dependencies:
      - Scan: x (4347983136)
        Context:
          Source: x AS x
        Projections:
      - Scan: y (4343416624)
        Context:
          Source: y AS y
        Projections:
Arguments:
  • expression: the expression to build the DAG from.
  • ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:

A Step DAG corresponding to expression.

name: str | None
dependencies: set[Step]
dependents: set[Step]
projections: Sequence[sqlglot.expressions.core.Expr]
limit: float
condition: sqlglot.expressions.core.Expr | None
def add_dependency(self, dependency: Step) -> None:
253    def add_dependency(self, dependency: Step) -> None:
254        self.dependencies.add(dependency)
255        dependency.dependents.add(self)
def to_s(self, level: int = 0) -> str:
260    def to_s(self, level: int = 0) -> str:
261        indent = "  " * level
262        nested = f"{indent}    "
263
264        context = self._to_s(f"{nested}  ")
265
266        if context:
267            context = [f"{nested}Context:"] + context
268
269        lines = [
270            f"{indent}- {self.id}",
271            *context,
272            f"{nested}Projections:",
273        ]
274
275        for expression in self.projections:
276            lines.append(f"{nested}  - {expression.sql()}")
277
278        if self.condition:
279            lines.append(f"{nested}Condition: {self.condition.sql()}")
280
281        if self.limit is not math.inf:
282            lines.append(f"{nested}Limit: {self.limit}")
283
284        if self.dependencies:
285            lines.append(f"{nested}Dependencies:")
286            for dependency in self.dependencies:
287                lines.append("  " + dependency.to_s(level + 1))
288
289        return "\n".join(lines)
type_name: str
291    @property
292    def type_name(self) -> str:
293        return self.__class__.__name__
id: str
295    @property
296    def id(self) -> str:
297        name = self.name
298        name = f" {name}" if name else ""
299        return f"{self.type_name}:{name} ({id(self)})"
class Scan(Step):
class Scan(Step):
    """Step backed by a single source expression, stored in `source`."""

    def __init__(self) -> None:
        super().__init__()
        self.source: exp.Expr | None = None

    @classmethod
    def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = None) -> Step:
        """
        Build the step for a source expression.

        If `expression` is an `exp.Subquery`, its inner expression is planned with
        `Step.from_expression` and the resulting step is renamed to the subquery's alias.
        Otherwise a `Scan` named after `expression.alias_or_name` is created, and it depends
        on the matching CTE step when `ctes` has an entry for `expression.name`.

        Args:
            expression: the source expression to build the step from.
            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.

        Returns:
            The step for `expression`.
        """
        name = expression.alias_or_name

        if isinstance(expression, exp.Subquery):
            subplan = Step.from_expression(expression.this, ctes)
            subplan.name = name
            return subplan

        scan = Scan()
        scan.name = name
        scan.source = expression

        if ctes and expression.name in ctes:
            scan.add_dependency(ctes[expression.name])

        return scan

    def _to_s(self, indent: str) -> list[str]:
        """Return the context lines for this scan: its source SQL, or `-static-` if unset."""
        source_sql = self.source.sql() if self.source else "-static-"
        return [f"{indent}Source: {source_sql}"]
@classmethod
def from_expression( cls, expression: sqlglot.expressions.core.Expr, ctes: dict[str, Step] | None = None) -> Step:
306    @classmethod
307    def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = None) -> Step:
308        table = expression
309        alias_ = expression.alias_or_name
310
311        if isinstance(expression, exp.Subquery):
312            table = expression.this
313            step = Step.from_expression(table, ctes)
314            step.name = alias_
315            return step
316
317        step = Scan()
318        step.name = alias_
319        step.source = expression
320        if ctes and table.name in ctes:
321            step.add_dependency(ctes[table.name])
322
323        return step

Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. Note: the expression's tables and subqueries must be aliased for this method to work. For example, given the following expression:

SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a

the following DAG is produced (the expression IDs might differ per execution):

- Aggregate: x (4347984624)
    Context:
      Aggregations:
        - SUM(x.b)
      Group:
        - x.a
    Projections:
      - x.a
      - "x".""
    Dependencies:
    - Join: x (4347985296)
      Context:
        y:
        On: x.a = y.a
      Projections:
      Dependencies:
      - Scan: x (4347983136)
        Context:
          Source: x AS x
        Projections:
      - Scan: y (4343416624)
        Context:
          Source: y AS y
        Projections:
Arguments:
  • expression: the expression to build the DAG from.
  • ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:

A Step DAG corresponding to expression.

class Join(Step):
class Join(Step):
    """Step that holds the details of one or more joins, keyed by the joined source's name."""

    def __init__(self) -> None:
        super().__init__()
        self.source_name: str | None = None
        self.joins: dict[str, dict[str, list[str] | exp.Expr]] = {}

    @classmethod
    def from_joins(cls, joins: Iterable[exp.Join], ctes: dict[str, Step] | None = None) -> Join:
        """
        Build a single `Join` step out of a collection of join expressions.

        For every join, the values returned by `join_condition` are stored under the join's
        `alias_or_name`, and the step built from the joined source becomes a dependency.

        Args:
            joins: the join expressions to plan.
            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.

        Returns:
            The `Join` step describing all of `joins`.
        """
        join_step = Join()

        for join in joins:
            source_key, join_key, condition = join_condition(join)
            details = {
                "side": join.side,  # type: ignore
                "join_key": join_key,
                "source_key": source_key,
                "condition": condition,
            }
            join_step.joins[join.alias_or_name] = details

            joined_source = Scan.from_expression(join.this, ctes)
            join_step.add_dependency(joined_source)

        return join_step

    def _to_s(self, indent: str) -> list[str]:
        """Return the context lines: the source name, then side / key / condition per join."""
        out = [f"{indent}Source: {self.source_name or self.name}"]

        for join_name, details in self.joins.items():
            out.append(f"{indent}{join_name}: {details['side'] or 'INNER'}")

            keys = t.cast(list, details.get("join_key") or [])
            rendered_keys = ", ".join(map(str, keys))
            if rendered_keys:
                out.append(f"{indent}Key: {rendered_keys}")

            on = details.get("condition")
            if on:
                out.append(f"{indent}On: {on.sql()}")  # type: ignore

        return out
@classmethod
def from_joins( cls, joins: Iterable[sqlglot.expressions.query.Join], ctes: dict[str, Step] | None = None) -> Join:
334    @classmethod
335    def from_joins(cls, joins: Iterable[exp.Join], ctes: dict[str, Step] | None = None) -> Join:
336        step = Join()
337
338        for join in joins:
339            source_key, join_key, condition = join_condition(join)
340            step.joins[join.alias_or_name] = {
341                "side": join.side,  # type: ignore
342                "join_key": join_key,
343                "source_key": source_key,
344                "condition": condition,
345            }
346
347            step.add_dependency(Scan.from_expression(join.this, ctes))
348
349        return step
source_name: str | None
joins: dict[str, dict[str, list[str] | sqlglot.expressions.core.Expr]]
class Aggregate(Step):
class Aggregate(Step):
    """Step that holds aggregations, their operands, the group-by keys and the source name."""

    def __init__(self) -> None:
        super().__init__()
        self.aggregations: list[exp.Expr] = []
        self.operands: tuple[exp.Expr, ...] = ()
        self.group: dict[str, exp.Expr] = {}
        self.source: str | None = None

    def _to_s(self, indent: str) -> list[str]:
        """Return the context lines: aggregations, then group, having and operands if set."""
        bullet = f"{indent}  - "

        out = [f"{indent}Aggregations:"]
        out.extend(f"{bullet}{aggregation.sql()}" for aggregation in self.aggregations)

        if self.group:
            out.append(f"{indent}Group:")
            out.extend(f"{bullet}{key.sql()}" for key in self.group.values())

        if self.condition:
            out += [f"{indent}Having:", f"{bullet}{self.condition.sql()}"]

        if self.operands:
            out.append(f"{indent}Operands:")
            out.extend(f"{bullet}{operand.sql()}" for operand in self.operands)

        return out
aggregations: list[sqlglot.expressions.core.Expr]
operands: tuple[sqlglot.expressions.core.Expr, ...]
group: dict[str, sqlglot.expressions.core.Expr]
source: str | None
class Sort(Step):
class Sort(Step):
    """Step that holds the ordering expressions in `key`."""

    def __init__(self) -> None:
        super().__init__()
        # Stays None until the planner assigns the ORDER BY expressions.
        self.key: Sequence[exp.Expr] | None = None

    def _to_s(self, indent: str) -> list[str]:
        """
        Return the context lines for this step: a `Key:` header and one line per key.

        Args:
            indent: the prefix to put in front of every line.

        Returns:
            The rendered lines; only the header when no key has been set.
        """
        lines = [f"{indent}Key:"]

        # `key` defaults to None, so guard the iteration to keep rendering a freshly
        # constructed Sort from raising a TypeError.
        for expression in self.key or ():
            lines.append(f"{indent}  - {expression.sql()}")

        return lines
key
class SetOperation(Step):
class SetOperation(Step):
    """Step that combines a left and a right step with a set operation class such as a union."""

    def __init__(
        self,
        op: type[exp.Expr],
        left: str | None,
        right: str | None,
        distinct: bool = False,
    ) -> None:
        """
        Args:
            op: the set operation expression class.
            left: the name of the left-hand step.
            right: the name of the right-hand step.
            distinct: whether the operation is flagged as distinct.
        """
        super().__init__()
        self.op = op
        self.left = left
        self.right = right
        self.distinct = distinct

    @classmethod
    def from_expression(
        cls, expression: exp.Expr, ctes: dict[str, Step] | None = None
    ) -> SetOperation:
        """
        Build a `SetOperation` step from a set operation expression.

        Both operands are planned with `Step.from_expression` and become dependencies of the
        returned step. Operands without a name fall back to "left" / "right".

        Args:
            expression: the set operation expression to build the step from.
            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.

        Returns:
            The step for `expression`, with its limit set when the expression has one.
        """
        assert isinstance(expression, exp.SetOperation)

        # SELECT 1 UNION SELECT 2  <-- these subqueries don't have names
        left_step = Step.from_expression(expression.left, ctes)
        left_step.name = left_step.name or "left"
        right_step = Step.from_expression(expression.right, ctes)
        right_step.name = right_step.name or "right"

        set_step = cls(
            op=type(expression),
            left=left_step.name,
            right=right_step.name,
            distinct=bool(expression.args.get("distinct")),
        )

        for operand in (left_step, right_step):
            set_step.add_dependency(operand)

        limit_node = expression.args.get("limit")
        if limit_node:
            set_step.limit = int(limit_node.text("expression"))

        return set_step

    def _to_s(self, indent: str) -> list[str]:
        """Return the context lines: a single `Distinct:` line when `distinct` is truthy."""
        return [f"{indent}Distinct: {self.distinct}"] if self.distinct else []

    @property
    def type_name(self) -> str:
        """The name of the set operation class stored in `op`."""
        return self.op.__name__
SetOperation( op: type[sqlglot.expressions.core.Expr], left: str | None, right: str | None, distinct: bool = False)
412    def __init__(
413        self,
414        op: type[exp.Expr],
415        left: str | None,
416        right: str | None,
417        distinct: bool = False,
418    ) -> None:
419        super().__init__()
420        self.op = op
421        self.left = left
422        self.right = right
423        self.distinct = distinct
op
left
right
distinct
@classmethod
def from_expression( cls, expression: sqlglot.expressions.core.Expr, ctes: dict[str, Step] | None = None) -> SetOperation:
425    @classmethod
426    def from_expression(
427        cls, expression: exp.Expr, ctes: dict[str, Step] | None = None
428    ) -> SetOperation:
429        assert isinstance(expression, exp.SetOperation)
430
431        left = Step.from_expression(expression.left, ctes)
432        # SELECT 1 UNION SELECT 2  <-- these subqueries don't have names
433        left.name = left.name or "left"
434        right = Step.from_expression(expression.right, ctes)
435        right.name = right.name or "right"
436        step = cls(
437            op=expression.__class__,
438            left=left.name,
439            right=right.name,
440            distinct=bool(expression.args.get("distinct")),
441        )
442
443        step.add_dependency(left)
444        step.add_dependency(right)
445
446        limit = expression.args.get("limit")
447
448        if limit:
449            step.limit = int(limit.text("expression"))
450
451        return step

Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. Note: the expression's tables and subqueries must be aliased for this method to work. For example, given the following expression:

SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a

the following DAG is produced (the expression IDs might differ per execution):

- Aggregate: x (4347984624)
    Context:
      Aggregations:
        - SUM(x.b)
      Group:
        - x.a
    Projections:
      - x.a
      - "x".""
    Dependencies:
    - Join: x (4347985296)
      Context:
        y:
        On: x.a = y.a
      Projections:
      Dependencies:
      - Scan: x (4347983136)
        Context:
          Source: x AS x
        Projections:
      - Scan: y (4343416624)
        Context:
          Source: y AS y
        Projections:
Arguments:
  • expression: the expression to build the DAG from.
  • ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:

A Step DAG corresponding to expression.

type_name: str
459    @property
460    def type_name(self) -> str:
461        return self.op.__name__