Edit on GitHub

sqlglot.planner

  1from __future__ import annotations
  2
  3import math
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.helper import name_sequence
  8from sqlglot.optimizer.eliminate_joins import join_condition
  9
 10
 11class Plan:
 12    def __init__(self, expression: exp.Expression) -> None:
 13        self.expression = expression.copy()
 14        self.root = Step.from_expression(self.expression)
 15        self._dag: t.Dict[Step, t.Set[Step]] = {}
 16
 17    @property
 18    def dag(self) -> t.Dict[Step, t.Set[Step]]:
 19        if not self._dag:
 20            dag: t.Dict[Step, t.Set[Step]] = {}
 21            nodes = {self.root}
 22
 23            while nodes:
 24                node = nodes.pop()
 25                dag[node] = set()
 26
 27                for dep in node.dependencies:
 28                    dag[node].add(dep)
 29                    nodes.add(dep)
 30
 31            self._dag = dag
 32
 33        return self._dag
 34
 35    @property
 36    def leaves(self) -> t.Iterator[Step]:
 37        return (node for node, deps in self.dag.items() if not deps)
 38
 39    def __repr__(self) -> str:
 40        return f"Plan\n----\n{repr(self.root)}"
 41
 42
 43class Step:
 44    @classmethod
 45    def from_expression(
 46        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
 47    ) -> Step:
 48        """
 49        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
 50        Note: the expression's tables and subqueries must be aliased for this method to work. For
 51        example, given the following expression:
 52
 53        SELECT
 54          x.a,
 55          SUM(x.b)
 56        FROM x AS x
 57        JOIN y AS y
 58          ON x.a = y.a
 59        GROUP BY x.a
 60
 61        the following DAG is produced (the expression IDs might differ per execution):
 62
 63        - Aggregate: x (4347984624)
 64            Context:
 65              Aggregations:
 66                - SUM(x.b)
 67              Group:
 68                - x.a
 69            Projections:
 70              - x.a
 71              - "x".""
 72            Dependencies:
 73            - Join: x (4347985296)
 74              Context:
 75                y:
 76                On: x.a = y.a
 77              Projections:
 78              Dependencies:
 79              - Scan: x (4347983136)
 80                Context:
 81                  Source: x AS x
 82                Projections:
 83              - Scan: y (4343416624)
 84                Context:
 85                  Source: y AS y
 86                Projections:
 87
 88        Args:
 89            expression: the expression to build the DAG from.
 90            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
 91
 92        Returns:
 93            A Step DAG corresponding to `expression`.
 94        """
 95        ctes = ctes or {}
 96        expression = expression.unnest()
 97        with_ = expression.args.get("with_")
 98
 99        # CTEs break the mold of scope and introduce themselves to all in the context.
100        if with_:
101            ctes = ctes.copy()
102            for cte in with_.expressions:
103                step = Step.from_expression(cte.this, ctes)
104                step.name = cte.alias
105                ctes[step.name] = step  # type: ignore
106
107        from_ = expression.args.get("from_")
108
109        if isinstance(expression, exp.Select) and from_:
110            step = Scan.from_expression(from_.this, ctes)
111        elif isinstance(expression, exp.SetOperation):
112            step = SetOperation.from_expression(expression, ctes)
113        else:
114            step = Scan()
115
116        joins = expression.args.get("joins")
117
118        if joins:
119            join = Join.from_joins(joins, ctes)
120            join.name = step.name
121            join.source_name = step.name
122            join.add_dependency(step)
123            step = join
124
125        projections: t.List[
126            exp.Expression
127        ] = []  # final selects in this chain of steps representing a select
128        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
129        aggregations = {}
130        next_operand_name = name_sequence("_a_")
131
132        def extract_agg_operands(expression):
133            agg_funcs = tuple(expression.find_all(exp.AggFunc))
134            if agg_funcs:
135                aggregations[expression] = None
136
137            for agg in agg_funcs:
138                for operand in agg.unnest_operands():
139                    if isinstance(operand, exp.Column):
140                        continue
141                    if operand not in operands:
142                        operands[operand] = next_operand_name()
143
144                    operand.replace(exp.column(operands[operand], quoted=True))
145
146            return bool(agg_funcs)
147
148        def set_ops_and_aggs(step):
149            step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
150            step.aggregations = list(aggregations)
151
152        for e in expression.expressions:
153            if e.find(exp.AggFunc):
154                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
155                extract_agg_operands(e)
156            else:
157                projections.append(e)
158
159        where = expression.args.get("where")
160
161        if where:
162            step.condition = where.this
163
164        group = expression.args.get("group")
165
166        if group or aggregations:
167            aggregate = Aggregate()
168            aggregate.source = step.name
169            aggregate.name = step.name
170
171            having = expression.args.get("having")
172
173            if having:
174                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
175                    aggregate.condition = exp.column("_h", step.name, quoted=True)
176                else:
177                    aggregate.condition = having.this
178
179            set_ops_and_aggs(aggregate)
180
181            # give aggregates names and replace projections with references to them
182            aggregate.group = {
183                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
184            }
185
186            intermediate: t.Dict[str | exp.Expression, str] = {}
187            for k, v in aggregate.group.items():
188                intermediate[v] = k
189                if isinstance(v, exp.Column):
190                    intermediate[v.name] = k
191
192            for projection in projections:
193                for node in projection.walk():
194                    name = intermediate.get(node)
195                    if name:
196                        node.replace(exp.column(name, step.name))
197
198            if aggregate.condition:
199                for node in aggregate.condition.walk():
200                    name = intermediate.get(node) or intermediate.get(node.name)
201                    if name:
202                        node.replace(exp.column(name, step.name))
203
204            aggregate.add_dependency(step)
205            step = aggregate
206        else:
207            aggregate = None
208
209        order = expression.args.get("order")
210
211        if order:
212            if aggregate and isinstance(step, Aggregate):
213                for i, ordered in enumerate(order.expressions):
214                    if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
215                        ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
216
217                set_ops_and_aggs(aggregate)
218
219            sort = Sort()
220            sort.name = step.name
221            sort.key = order.expressions
222            sort.add_dependency(step)
223            step = sort
224
225        step.projections = projections
226
227        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
228            distinct = Aggregate()
229            distinct.source = step.name
230            distinct.name = step.name
231            distinct.group = {
232                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
233                for e in projections or expression.expressions
234            }
235            distinct.add_dependency(step)
236            step = distinct
237
238        limit = expression.args.get("limit")
239
240        if limit:
241            step.limit = int(limit.text("expression"))
242
243        return step
244
245    def __init__(self) -> None:
246        self.name: t.Optional[str] = None
247        self.dependencies: t.Set[Step] = set()
248        self.dependents: t.Set[Step] = set()
249        self.projections: t.Sequence[exp.Expression] = []
250        self.limit: float = math.inf
251        self.condition: t.Optional[exp.Expression] = None
252
253    def add_dependency(self, dependency: Step) -> None:
254        self.dependencies.add(dependency)
255        dependency.dependents.add(self)
256
257    def __repr__(self) -> str:
258        return self.to_s()
259
260    def to_s(self, level: int = 0) -> str:
261        indent = "  " * level
262        nested = f"{indent}    "
263
264        context = self._to_s(f"{nested}  ")
265
266        if context:
267            context = [f"{nested}Context:"] + context
268
269        lines = [
270            f"{indent}- {self.id}",
271            *context,
272            f"{nested}Projections:",
273        ]
274
275        for expression in self.projections:
276            lines.append(f"{nested}  - {expression.sql()}")
277
278        if self.condition:
279            lines.append(f"{nested}Condition: {self.condition.sql()}")
280
281        if self.limit is not math.inf:
282            lines.append(f"{nested}Limit: {self.limit}")
283
284        if self.dependencies:
285            lines.append(f"{nested}Dependencies:")
286            for dependency in self.dependencies:
287                lines.append("  " + dependency.to_s(level + 1))
288
289        return "\n".join(lines)
290
291    @property
292    def type_name(self) -> str:
293        return self.__class__.__name__
294
295    @property
296    def id(self) -> str:
297        name = self.name
298        name = f" {name}" if name else ""
299        return f"{self.type_name}:{name} ({id(self)})"
300
301    def _to_s(self, _indent: str) -> t.List[str]:
302        return []
303
304
class Scan(Step):
    """Step that reads from a source expression; `source` is None for a bare scan."""

    @classmethod
    def from_expression(
        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> Step:
        """
        Build the step for a FROM/JOIN source.

        A subquery is planned through `Step.from_expression` and the resulting step is
        renamed to the subquery's alias. Anything else becomes a `Scan` over
        `expression`, depending on the matching CTE step when its name is in `ctes`.

        Args:
            expression: the source expression.
            ctes: a dictionary that maps CTE names to their Step DAG.

        Returns:
            The step producing the rows of `expression`.
        """
        name = expression.alias_or_name

        if isinstance(expression, exp.Subquery):
            planned = Step.from_expression(expression.this, ctes)
            planned.name = name
            return planned

        scan = Scan()
        scan.name = name
        scan.source = expression

        if ctes:
            referenced = expression.name
            if referenced in ctes:
                scan.add_dependency(ctes[referenced])

        return scan

    def __init__(self) -> None:
        super().__init__()
        self.source: t.Optional[exp.Expression] = None

    def _to_s(self, indent: str) -> t.List[str]:
        """Describe the scanned source, or '-static-' when there is none."""
        rendered = self.source.sql() if self.source else "-static-"  # type: ignore
        return [f"{indent}Source: {rendered}"]
333
334
class Join(Step):
    """Step that joins a source step with one or more joined sources."""

    @classmethod
    def from_joins(
        cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> Join:
        """
        Build a single Join step covering all of `joins`.

        For every join, the values returned by `join_condition` are stored under the
        join's alias (or name), and a step for the joined source is added as a
        dependency.

        Args:
            joins: the join expressions of a select.
            ctes: a dictionary that maps CTE names to their Step DAG.

        Returns:
            The populated Join step.
        """
        result = Join()

        for joined in joins:
            source_key, join_key, condition = join_condition(joined)
            spec = {
                "side": joined.side,  # type: ignore
                "join_key": join_key,
                "source_key": source_key,
                "condition": condition,
            }
            result.joins[joined.alias_or_name] = spec

            dependency = Scan.from_expression(joined.this, ctes)
            result.add_dependency(dependency)

        return result

    def __init__(self) -> None:
        super().__init__()
        self.source_name: t.Optional[str] = None
        self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}

    def _to_s(self, indent: str) -> t.List[str]:
        """List the source name, then side, keys and ON condition of each join."""
        out = [f"{indent}Source: {self.source_name or self.name}"]

        for alias_name, spec in self.joins.items():
            out.append(f"{indent}{alias_name}: {spec['side'] or 'INNER'}")

            keys = t.cast(list, spec.get("join_key") or [])
            rendered_keys = ", ".join(map(str, keys))
            if rendered_keys:
                out.append(f"{indent}Key: {rendered_keys}")

            condition = spec.get("condition")
            if condition:
                out.append(f"{indent}On: {condition.sql()}")  # type: ignore

        return out
370
371
class Aggregate(Step):
    """Step holding aggregations, their operands, group keys and the source name."""

    def __init__(self) -> None:
        super().__init__()
        self.aggregations: t.List[exp.Expression] = []
        self.operands: t.Tuple[exp.Expression, ...] = ()
        self.group: t.Dict[str, exp.Expression] = {}
        self.source: t.Optional[str] = None

    def _to_s(self, indent: str) -> t.List[str]:
        """List aggregations, then group keys, having condition and operands if set."""

        def bullets(expressions: t.Iterable[exp.Expression]) -> t.List[str]:
            return [f"{indent}  - {item.sql()}" for item in expressions]

        out = [f"{indent}Aggregations:"]
        out.extend(bullets(self.aggregations))

        if self.group:
            out.append(f"{indent}Group:")
            out.extend(bullets(self.group.values()))

        if self.condition:
            out.append(f"{indent}Having:")
            out.append(f"{indent}  - {self.condition.sql()}")

        if self.operands:
            out.append(f"{indent}Operands:")
            out.extend(bullets(self.operands))

        return out
399
400
class Sort(Step):
    """Step that orders its input by the expressions stored in `key`."""

    def __init__(self) -> None:
        super().__init__()
        # Ordering expressions; None until assigned by the code building the plan.
        self.key: t.Optional[t.Sequence[exp.Expression]] = None

    def _to_s(self, indent: str) -> t.List[str]:
        """List the sort keys, one per line, under a "Key:" header."""
        lines = [f"{indent}Key:"]

        # `key` starts out as None; treat that as "no keys" so that repr() of a
        # freshly created Sort does not raise.
        for expression in self.key or ():
            lines.append(f"{indent}  - {expression.sql()}")

        return lines
413
414
class SetOperation(Step):
    """Step combining a left and a right step with a set operator class `op`."""

    def __init__(
        self,
        op: t.Type[exp.Expression],
        left: str | None,
        right: str | None,
        distinct: bool = False,
    ) -> None:
        super().__init__()
        self.op = op
        self.left = left
        self.right = right
        self.distinct = distinct

    @classmethod
    def from_expression(
        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> SetOperation:
        """
        Build the step for a set operation from its two sides.

        Args:
            expression: the set operation; must be an `exp.SetOperation`.
            ctes: a dictionary that maps CTE names to their Step DAG.

        Returns:
            A SetOperation step depending on the steps planned for both sides.
        """
        assert isinstance(expression, exp.SetOperation)

        # SELECT 1 UNION SELECT 2  <-- these subqueries don't have names
        left_step = Step.from_expression(expression.left, ctes)
        if not left_step.name:
            left_step.name = "left"

        right_step = Step.from_expression(expression.right, ctes)
        if not right_step.name:
            right_step.name = "right"

        is_distinct = bool(expression.args.get("distinct"))
        result = cls(
            op=expression.__class__,
            left=left_step.name,
            right=right_step.name,
            distinct=is_distinct,
        )

        for side in (left_step, right_step):
            result.add_dependency(side)

        limit_node = expression.args.get("limit")
        if limit_node:
            result.limit = int(limit_node.text("expression"))

        return result

    def _to_s(self, indent: str) -> t.List[str]:
        """Emit a "Distinct" line only when the operation is distinct."""
        if self.distinct:
            return [f"{indent}Distinct: {self.distinct}"]
        return []

    @property
    def type_name(self) -> str:
        """Label the step with the name of the operator class rather than this class."""
        return self.op.__name__
class Plan:
class Plan:
    """Execution plan for a SQL expression, exposed as a DAG of `Step` objects.

    The plan is built from a copy of the given expression; `root` is the step
    returned by `Step.from_expression` for that copy.
    """

    def __init__(self, expression: exp.Expression) -> None:
        # Step.from_expression replaces nodes inside the tree it is given
        # (e.g. aggregate operands), so the plan works on its own copy.
        self.expression = expression.copy()
        self.root = Step.from_expression(self.expression)
        # Filled lazily by the `dag` property.
        self._dag: t.Dict[Step, t.Set[Step]] = {}

    @property
    def dag(self) -> t.Dict[Step, t.Set[Step]]:
        """
        Map every step reachable from `root` to the set of its direct dependencies.

        The mapping is computed on first access and cached on the instance.
        """
        if not self._dag:
            dag: t.Dict[Step, t.Set[Step]] = {}
            pending = {self.root}

            while pending:
                node = pending.pop()
                deps = set(node.dependencies)
                dag[node] = deps

                # Queue only steps that have not been expanded yet; a dependency
                # shared by several steps would otherwise be walked repeatedly.
                pending.update(dep for dep in deps if dep not in dag)

            self._dag = dag

        return self._dag

    @property
    def leaves(self) -> t.Iterator[Step]:
        """Iterate over the steps in the DAG that have no dependencies."""
        return (node for node, deps in self.dag.items() if not deps)

    def __repr__(self) -> str:
        return f"Plan\n----\n{repr(self.root)}"
Plan(expression: sqlglot.expressions.Expression)
13    def __init__(self, expression: exp.Expression) -> None:
14        self.expression = expression.copy()
15        self.root = Step.from_expression(self.expression)
16        self._dag: t.Dict[Step, t.Set[Step]] = {}
expression
root
dag: Dict[Step, Set[Step]]
18    @property
19    def dag(self) -> t.Dict[Step, t.Set[Step]]:
20        if not self._dag:
21            dag: t.Dict[Step, t.Set[Step]] = {}
22            nodes = {self.root}
23
24            while nodes:
25                node = nodes.pop()
26                dag[node] = set()
27
28                for dep in node.dependencies:
29                    dag[node].add(dep)
30                    nodes.add(dep)
31
32            self._dag = dag
33
34        return self._dag
leaves: Iterator[Step]
36    @property
37    def leaves(self) -> t.Iterator[Step]:
38        return (node for node, deps in self.dag.items() if not deps)
class Step:
 44class Step:
 45    @classmethod
 46    def from_expression(
 47        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
 48    ) -> Step:
 49        """
 50        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
 51        Note: the expression's tables and subqueries must be aliased for this method to work. For
 52        example, given the following expression:
 53
 54        SELECT
 55          x.a,
 56          SUM(x.b)
 57        FROM x AS x
 58        JOIN y AS y
 59          ON x.a = y.a
 60        GROUP BY x.a
 61
 62        the following DAG is produced (the expression IDs might differ per execution):
 63
 64        - Aggregate: x (4347984624)
 65            Context:
 66              Aggregations:
 67                - SUM(x.b)
 68              Group:
 69                - x.a
 70            Projections:
 71              - x.a
 72              - "x".""
 73            Dependencies:
 74            - Join: x (4347985296)
 75              Context:
 76                y:
 77                On: x.a = y.a
 78              Projections:
 79              Dependencies:
 80              - Scan: x (4347983136)
 81                Context:
 82                  Source: x AS x
 83                Projections:
 84              - Scan: y (4343416624)
 85                Context:
 86                  Source: y AS y
 87                Projections:
 88
 89        Args:
 90            expression: the expression to build the DAG from.
 91            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
 92
 93        Returns:
 94            A Step DAG corresponding to `expression`.
 95        """
 96        ctes = ctes or {}
 97        expression = expression.unnest()
 98        with_ = expression.args.get("with_")
 99
100        # CTEs break the mold of scope and introduce themselves to all in the context.
101        if with_:
102            ctes = ctes.copy()
103            for cte in with_.expressions:
104                step = Step.from_expression(cte.this, ctes)
105                step.name = cte.alias
106                ctes[step.name] = step  # type: ignore
107
108        from_ = expression.args.get("from_")
109
110        if isinstance(expression, exp.Select) and from_:
111            step = Scan.from_expression(from_.this, ctes)
112        elif isinstance(expression, exp.SetOperation):
113            step = SetOperation.from_expression(expression, ctes)
114        else:
115            step = Scan()
116
117        joins = expression.args.get("joins")
118
119        if joins:
120            join = Join.from_joins(joins, ctes)
121            join.name = step.name
122            join.source_name = step.name
123            join.add_dependency(step)
124            step = join
125
126        projections: t.List[
127            exp.Expression
128        ] = []  # final selects in this chain of steps representing a select
129        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
130        aggregations = {}
131        next_operand_name = name_sequence("_a_")
132
133        def extract_agg_operands(expression):
134            agg_funcs = tuple(expression.find_all(exp.AggFunc))
135            if agg_funcs:
136                aggregations[expression] = None
137
138            for agg in agg_funcs:
139                for operand in agg.unnest_operands():
140                    if isinstance(operand, exp.Column):
141                        continue
142                    if operand not in operands:
143                        operands[operand] = next_operand_name()
144
145                    operand.replace(exp.column(operands[operand], quoted=True))
146
147            return bool(agg_funcs)
148
149        def set_ops_and_aggs(step):
150            step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
151            step.aggregations = list(aggregations)
152
153        for e in expression.expressions:
154            if e.find(exp.AggFunc):
155                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
156                extract_agg_operands(e)
157            else:
158                projections.append(e)
159
160        where = expression.args.get("where")
161
162        if where:
163            step.condition = where.this
164
165        group = expression.args.get("group")
166
167        if group or aggregations:
168            aggregate = Aggregate()
169            aggregate.source = step.name
170            aggregate.name = step.name
171
172            having = expression.args.get("having")
173
174            if having:
175                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
176                    aggregate.condition = exp.column("_h", step.name, quoted=True)
177                else:
178                    aggregate.condition = having.this
179
180            set_ops_and_aggs(aggregate)
181
182            # give aggregates names and replace projections with references to them
183            aggregate.group = {
184                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
185            }
186
187            intermediate: t.Dict[str | exp.Expression, str] = {}
188            for k, v in aggregate.group.items():
189                intermediate[v] = k
190                if isinstance(v, exp.Column):
191                    intermediate[v.name] = k
192
193            for projection in projections:
194                for node in projection.walk():
195                    name = intermediate.get(node)
196                    if name:
197                        node.replace(exp.column(name, step.name))
198
199            if aggregate.condition:
200                for node in aggregate.condition.walk():
201                    name = intermediate.get(node) or intermediate.get(node.name)
202                    if name:
203                        node.replace(exp.column(name, step.name))
204
205            aggregate.add_dependency(step)
206            step = aggregate
207        else:
208            aggregate = None
209
210        order = expression.args.get("order")
211
212        if order:
213            if aggregate and isinstance(step, Aggregate):
214                for i, ordered in enumerate(order.expressions):
215                    if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
216                        ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
217
218                set_ops_and_aggs(aggregate)
219
220            sort = Sort()
221            sort.name = step.name
222            sort.key = order.expressions
223            sort.add_dependency(step)
224            step = sort
225
226        step.projections = projections
227
228        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
229            distinct = Aggregate()
230            distinct.source = step.name
231            distinct.name = step.name
232            distinct.group = {
233                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
234                for e in projections or expression.expressions
235            }
236            distinct.add_dependency(step)
237            step = distinct
238
239        limit = expression.args.get("limit")
240
241        if limit:
242            step.limit = int(limit.text("expression"))
243
244        return step
245
246    def __init__(self) -> None:
247        self.name: t.Optional[str] = None
248        self.dependencies: t.Set[Step] = set()
249        self.dependents: t.Set[Step] = set()
250        self.projections: t.Sequence[exp.Expression] = []
251        self.limit: float = math.inf
252        self.condition: t.Optional[exp.Expression] = None
253
254    def add_dependency(self, dependency: Step) -> None:
255        self.dependencies.add(dependency)
256        dependency.dependents.add(self)
257
258    def __repr__(self) -> str:
259        return self.to_s()
260
261    def to_s(self, level: int = 0) -> str:
262        indent = "  " * level
263        nested = f"{indent}    "
264
265        context = self._to_s(f"{nested}  ")
266
267        if context:
268            context = [f"{nested}Context:"] + context
269
270        lines = [
271            f"{indent}- {self.id}",
272            *context,
273            f"{nested}Projections:",
274        ]
275
276        for expression in self.projections:
277            lines.append(f"{nested}  - {expression.sql()}")
278
279        if self.condition:
280            lines.append(f"{nested}Condition: {self.condition.sql()}")
281
282        if self.limit is not math.inf:
283            lines.append(f"{nested}Limit: {self.limit}")
284
285        if self.dependencies:
286            lines.append(f"{nested}Dependencies:")
287            for dependency in self.dependencies:
288                lines.append("  " + dependency.to_s(level + 1))
289
290        return "\n".join(lines)
291
292    @property
293    def type_name(self) -> str:
294        return self.__class__.__name__
295
296    @property
297    def id(self) -> str:
298        name = self.name
299        name = f" {name}" if name else ""
300        return f"{self.type_name}:{name} ({id(self)})"
301
302    def _to_s(self, _indent: str) -> t.List[str]:
303        return []
@classmethod
def from_expression( cls, expression: sqlglot.expressions.Expression, ctes: Optional[Dict[str, Step]] = None) -> Step:
 45    @classmethod
 46    def from_expression(
 47        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
 48    ) -> Step:
 49        """
 50        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
 51        Note: the expression's tables and subqueries must be aliased for this method to work. For
 52        example, given the following expression:
 53
 54        SELECT
 55          x.a,
 56          SUM(x.b)
 57        FROM x AS x
 58        JOIN y AS y
 59          ON x.a = y.a
 60        GROUP BY x.a
 61
 62        the following DAG is produced (the expression IDs might differ per execution):
 63
 64        - Aggregate: x (4347984624)
 65            Context:
 66              Aggregations:
 67                - SUM(x.b)
 68              Group:
 69                - x.a
 70            Projections:
 71              - x.a
 72              - "x".""
 73            Dependencies:
 74            - Join: x (4347985296)
 75              Context:
 76                y:
 77                On: x.a = y.a
 78              Projections:
 79              Dependencies:
 80              - Scan: x (4347983136)
 81                Context:
 82                  Source: x AS x
 83                Projections:
 84              - Scan: y (4343416624)
 85                Context:
 86                  Source: y AS y
 87                Projections:
 88
 89        Args:
 90            expression: the expression to build the DAG from.
 91            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
 92
 93        Returns:
 94            A Step DAG corresponding to `expression`.
 95        """
 96        ctes = ctes or {}
 97        expression = expression.unnest()
 98        with_ = expression.args.get("with_")
 99
100        # CTEs break the mold of scope and introduce themselves to all in the context.
101        if with_:
102            ctes = ctes.copy()
103            for cte in with_.expressions:
104                step = Step.from_expression(cte.this, ctes)
105                step.name = cte.alias
106                ctes[step.name] = step  # type: ignore
107
108        from_ = expression.args.get("from_")
109
110        if isinstance(expression, exp.Select) and from_:
111            step = Scan.from_expression(from_.this, ctes)
112        elif isinstance(expression, exp.SetOperation):
113            step = SetOperation.from_expression(expression, ctes)
114        else:
115            step = Scan()
116
117        joins = expression.args.get("joins")
118
119        if joins:
120            join = Join.from_joins(joins, ctes)
121            join.name = step.name
122            join.source_name = step.name
123            join.add_dependency(step)
124            step = join
125
126        projections: t.List[
127            exp.Expression
128        ] = []  # final selects in this chain of steps representing a select
129        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
130        aggregations = {}
131        next_operand_name = name_sequence("_a_")
132
133        def extract_agg_operands(expression):
134            agg_funcs = tuple(expression.find_all(exp.AggFunc))
135            if agg_funcs:
136                aggregations[expression] = None
137
138            for agg in agg_funcs:
139                for operand in agg.unnest_operands():
140                    if isinstance(operand, exp.Column):
141                        continue
142                    if operand not in operands:
143                        operands[operand] = next_operand_name()
144
145                    operand.replace(exp.column(operands[operand], quoted=True))
146
147            return bool(agg_funcs)
148
149        def set_ops_and_aggs(step):
150            step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
151            step.aggregations = list(aggregations)
152
153        for e in expression.expressions:
154            if e.find(exp.AggFunc):
155                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
156                extract_agg_operands(e)
157            else:
158                projections.append(e)
159
160        where = expression.args.get("where")
161
162        if where:
163            step.condition = where.this
164
165        group = expression.args.get("group")
166
167        if group or aggregations:
168            aggregate = Aggregate()
169            aggregate.source = step.name
170            aggregate.name = step.name
171
172            having = expression.args.get("having")
173
174            if having:
175                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
176                    aggregate.condition = exp.column("_h", step.name, quoted=True)
177                else:
178                    aggregate.condition = having.this
179
180            set_ops_and_aggs(aggregate)
181
182            # give aggregates names and replace projections with references to them
183            aggregate.group = {
184                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
185            }
186
187            intermediate: t.Dict[str | exp.Expression, str] = {}
188            for k, v in aggregate.group.items():
189                intermediate[v] = k
190                if isinstance(v, exp.Column):
191                    intermediate[v.name] = k
192
193            for projection in projections:
194                for node in projection.walk():
195                    name = intermediate.get(node)
196                    if name:
197                        node.replace(exp.column(name, step.name))
198
199            if aggregate.condition:
200                for node in aggregate.condition.walk():
201                    name = intermediate.get(node) or intermediate.get(node.name)
202                    if name:
203                        node.replace(exp.column(name, step.name))
204
205            aggregate.add_dependency(step)
206            step = aggregate
207        else:
208            aggregate = None
209
210        order = expression.args.get("order")
211
212        if order:
213            if aggregate and isinstance(step, Aggregate):
214                for i, ordered in enumerate(order.expressions):
215                    if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
216                        ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
217
218                set_ops_and_aggs(aggregate)
219
220            sort = Sort()
221            sort.name = step.name
222            sort.key = order.expressions
223            sort.add_dependency(step)
224            step = sort
225
226        step.projections = projections
227
228        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
229            distinct = Aggregate()
230            distinct.source = step.name
231            distinct.name = step.name
232            distinct.group = {
233                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
234                for e in projections or expression.expressions
235            }
236            distinct.add_dependency(step)
237            step = distinct
238
239        limit = expression.args.get("limit")
240
241        if limit:
242            step.limit = int(limit.text("expression"))
243
244        return step

Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. Note: the expression's tables and subqueries must be aliased for this method to work. For example, given the following expression:

SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a

the following DAG is produced (the expression IDs might differ per execution):

- Aggregate: x (4347984624)
    Context:
      Aggregations:
        - SUM(x.b)
      Group:
        - x.a
    Projections:
      - x.a
      - "x".""
    Dependencies:
    - Join: x (4347985296)
      Context:
        y:
        On: x.a = y.a
      Projections:
      Dependencies:
      - Scan: x (4347983136)
        Context:
          Source: x AS x
        Projections:
      - Scan: y (4343416624)
        Context:
          Source: y AS y
        Projections:
Arguments:
  • expression: the expression to build the DAG from.
  • ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:

A Step DAG corresponding to expression.

name: Optional[str]
dependencies: Set[Step]
dependents: Set[Step]
projections: Sequence[sqlglot.expressions.Expression]
limit: float
condition: Optional[sqlglot.expressions.Expression]
def add_dependency(self, dependency: Step) -> None:
254    def add_dependency(self, dependency: Step) -> None:
255        self.dependencies.add(dependency)
256        dependency.dependents.add(self)
def to_s(self, level: int = 0) -> str:
261    def to_s(self, level: int = 0) -> str:
262        indent = "  " * level
263        nested = f"{indent}    "
264
265        context = self._to_s(f"{nested}  ")
266
267        if context:
268            context = [f"{nested}Context:"] + context
269
270        lines = [
271            f"{indent}- {self.id}",
272            *context,
273            f"{nested}Projections:",
274        ]
275
276        for expression in self.projections:
277            lines.append(f"{nested}  - {expression.sql()}")
278
279        if self.condition:
280            lines.append(f"{nested}Condition: {self.condition.sql()}")
281
282        if self.limit is not math.inf:
283            lines.append(f"{nested}Limit: {self.limit}")
284
285        if self.dependencies:
286            lines.append(f"{nested}Dependencies:")
287            for dependency in self.dependencies:
288                lines.append("  " + dependency.to_s(level + 1))
289
290        return "\n".join(lines)
type_name: str
292    @property
293    def type_name(self) -> str:
294        return self.__class__.__name__
id: str
296    @property
297    def id(self) -> str:
298        name = self.name
299        name = f" {name}" if name else ""
300        return f"{self.type_name}:{name} ({id(self)})"
class Scan(Step):
class Scan(Step):
    """A step that wraps the source expression a query reads from."""

    @classmethod
    def from_expression(
        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> Step:
        """
        Builds the step for a single source expression.

        A subquery is planned through `Step.from_expression` and the resulting step takes the
        subquery's alias as its name. Any other expression becomes a `Scan`, which depends on
        the matching CTE step when the expression's name is a key of `ctes`.

        Args:
            expression: the source expression to build the step from.
            ctes: a dictionary that maps CTE names to their corresponding steps.

        Returns:
            The step built for `expression`.
        """
        step_name = expression.alias_or_name

        if isinstance(expression, exp.Subquery):
            subplan = Step.from_expression(expression.this, ctes)
            subplan.name = step_name
            return subplan

        scan = Scan()
        scan.name = step_name
        scan.source = expression

        if ctes:
            cte_name = expression.name
            if cte_name in ctes:
                scan.add_dependency(ctes[cte_name])

        return scan

    def __init__(self) -> None:
        super().__init__()
        self.source: t.Optional[exp.Expression] = None

    def _to_s(self, indent: str) -> t.List[str]:
        """Returns the single context line naming the source, or '-static-' without one."""
        source_sql = self.source.sql() if self.source else "-static-"  # type: ignore
        return [f"{indent}Source: {source_sql}"]
@classmethod
def from_expression( cls, expression: sqlglot.expressions.Expression, ctes: Optional[Dict[str, Step]] = None) -> Step:
307    @classmethod
308    def from_expression(
309        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
310    ) -> Step:
311        table = expression
312        alias_ = expression.alias_or_name
313
314        if isinstance(expression, exp.Subquery):
315            table = expression.this
316            step = Step.from_expression(table, ctes)
317            step.name = alias_
318            return step
319
320        step = Scan()
321        step.name = alias_
322        step.source = expression
323        if ctes and table.name in ctes:
324            step.add_dependency(ctes[table.name])
325
326        return step

Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. Note: the expression's tables and subqueries must be aliased for this method to work. For example, given the following expression:

SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a

the following DAG is produced (the expression IDs might differ per execution):

- Aggregate: x (4347984624)
    Context:
      Aggregations:
        - SUM(x.b)
      Group:
        - x.a
    Projections:
      - x.a
      - "x".""
    Dependencies:
    - Join: x (4347985296)
      Context:
        y:
        On: x.a = y.a
      Projections:
      Dependencies:
      - Scan: x (4347983136)
        Context:
          Source: x AS x
        Projections:
      - Scan: y (4343416624)
        Context:
          Source: y AS y
        Projections:
Arguments:
  • expression: the expression to build the DAG from.
  • ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:

A Step DAG corresponding to expression.

source: Optional[sqlglot.expressions.Expression]
class Join(Step):
class Join(Step):
    """A step describing the joins applied to a named source step."""

    @classmethod
    def from_joins(
        cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> Join:
        """
        Builds a single `Join` step out of a sequence of join expressions.

        Each join is recorded under its alias (or name) together with its side, keys and
        remaining condition, and the step built from the joined expression is added as a
        dependency.

        Args:
            joins: the join expressions to plan.
            ctes: a dictionary that maps CTE names to their corresponding steps.

        Returns:
            The `Join` step covering all of `joins`.
        """
        plan = Join()

        for join_expr in joins:
            source_key, join_key, condition = join_condition(join_expr)
            details = {
                "side": join_expr.side,  # type: ignore
                "join_key": join_key,
                "source_key": source_key,
                "condition": condition,
            }
            plan.joins[join_expr.alias_or_name] = details

            joined = Scan.from_expression(join_expr.this, ctes)
            plan.add_dependency(joined)

        return plan

    def __init__(self) -> None:
        super().__init__()
        self.source_name: t.Optional[str] = None
        self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}

    def _to_s(self, indent: str) -> t.List[str]:
        """Returns the context lines: the source name, then side, key and condition per join."""
        lines = [f"{indent}Source: {self.source_name or self.name}"]

        for join_name, details in self.joins.items():
            # A falsy side is displayed as INNER.
            lines.append(f"{indent}{join_name}: {details['side'] or 'INNER'}")

            keys = t.cast(list, details.get("join_key") or [])
            key_text = ", ".join(map(str, keys))
            if key_text:
                lines.append(f"{indent}Key: {key_text}")

            condition = details.get("condition")
            if condition:
                lines.append(f"{indent}On: {condition.sql()}")  # type: ignore

        return lines
@classmethod
def from_joins( cls, joins: Iterable[sqlglot.expressions.Join], ctes: Optional[Dict[str, Step]] = None) -> Join:
337    @classmethod
338    def from_joins(
339        cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
340    ) -> Join:
341        step = Join()
342
343        for join in joins:
344            source_key, join_key, condition = join_condition(join)
345            step.joins[join.alias_or_name] = {
346                "side": join.side,  # type: ignore
347                "join_key": join_key,
348                "source_key": source_key,
349                "condition": condition,
350            }
351
352            step.add_dependency(Scan.from_expression(join.this, ctes))
353
354        return step
source_name: Optional[str]
joins: Dict[str, Dict[str, Union[List[str], sqlglot.expressions.Expression]]]
class Aggregate(Step):
class Aggregate(Step):
    """A step holding aggregation expressions, their operands, grouping keys and a source name."""

    def __init__(self) -> None:
        super().__init__()
        self.aggregations: t.List[exp.Expression] = []
        self.operands: t.Tuple[exp.Expression, ...] = ()
        self.group: t.Dict[str, exp.Expression] = {}
        self.source: t.Optional[str] = None

    def _to_s(self, indent: str) -> t.List[str]:
        """Returns the context lines: aggregations, then group, having and operands if set."""
        bullet = f"{indent}  - "

        lines = [f"{indent}Aggregations:"]
        lines.extend(f"{bullet}{aggregation.sql()}" for aggregation in self.aggregations)

        if self.group:
            lines.append(f"{indent}Group:")
            lines.extend(f"{bullet}{key.sql()}" for key in self.group.values())

        # The step's condition is displayed under the "Having:" label.
        if self.condition:
            lines += [f"{indent}Having:", f"{bullet}{self.condition.sql()}"]

        if self.operands:
            lines.append(f"{indent}Operands:")
            lines.extend(f"{bullet}{operand.sql()}" for operand in self.operands)

        return lines
aggregations: List[sqlglot.expressions.Expression]
operands: Tuple[sqlglot.expressions.Expression, ...]
group: Dict[str, sqlglot.expressions.Expression]
source: Optional[str]
class Sort(Step):
class Sort(Step):
    """A step holding the expressions to order by."""

    def __init__(self) -> None:
        super().__init__()
        # None until the ordering expressions are assigned.
        self.key: t.Optional[t.List[exp.Expression]] = None

    def _to_s(self, indent: str) -> t.List[str]:
        """
        Returns the context lines: a "Key:" header followed by one line per ordering expression.

        Args:
            indent: the prefix prepended to every returned line.

        Returns:
            The list of context lines; only the header when no key has been assigned.
        """
        lines = [f"{indent}Key:"]

        # The key defaults to None; iterate an empty tuple in that case so that rendering a
        # Sort whose key has not been assigned yet does not raise a TypeError.
        for expression in self.key or ():
            lines.append(f"{indent}  - {expression.sql()}")

        return lines
key
class SetOperation(Step):
class SetOperation(Step):
    """A step that applies a set operation class to two named input steps."""

    def __init__(
        self,
        op: t.Type[exp.Expression],
        left: str | None,
        right: str | None,
        distinct: bool = False,
    ) -> None:
        """
        Args:
            op: the expression class of the set operation.
            left: the name of the left input step.
            right: the name of the right input step.
            distinct: whether the operation is flagged as distinct.
        """
        super().__init__()
        self.op, self.left = op, left
        self.right, self.distinct = right, distinct

    @classmethod
    def from_expression(
        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> SetOperation:
        """
        Builds a `SetOperation` step from a set operation expression.

        Both sides are planned through `Step.from_expression` and added as dependencies; a
        side without a name is named "left" or "right" respectively.

        Args:
            expression: the set operation expression to build the step from.
            ctes: a dictionary that maps CTE names to their corresponding steps.

        Returns:
            The step corresponding to `expression`.
        """
        assert isinstance(expression, exp.SetOperation)

        # SELECT 1 UNION SELECT 2  <-- these subqueries don't have names
        left_step = Step.from_expression(expression.left, ctes)
        if not left_step.name:
            left_step.name = "left"

        right_step = Step.from_expression(expression.right, ctes)
        if not right_step.name:
            right_step.name = "right"

        is_distinct = bool(expression.args.get("distinct"))
        operation = cls(
            op=expression.__class__,
            left=left_step.name,
            right=right_step.name,
            distinct=is_distinct,
        )

        for side in (left_step, right_step):
            operation.add_dependency(side)

        limit_expr = expression.args.get("limit")
        if limit_expr:
            operation.limit = int(limit_expr.text("expression"))

        return operation

    def _to_s(self, indent: str) -> t.List[str]:
        """Returns a single "Distinct:" context line when the step is distinct, else nothing."""
        return [f"{indent}Distinct: {self.distinct}"] if self.distinct else []

    @property
    def type_name(self) -> str:
        """The name of the set operation's expression class, not of this step class."""
        operation_class = self.op
        return operation_class.__name__
SetOperation( op: Type[sqlglot.expressions.Expression], left: str | None, right: str | None, distinct: bool = False)
417    def __init__(
418        self,
419        op: t.Type[exp.Expression],
420        left: str | None,
421        right: str | None,
422        distinct: bool = False,
423    ) -> None:
424        super().__init__()
425        self.op = op
426        self.left = left
427        self.right = right
428        self.distinct = distinct
op
left
right
distinct
@classmethod
def from_expression( cls, expression: sqlglot.expressions.Expression, ctes: Optional[Dict[str, Step]] = None) -> SetOperation:
430    @classmethod
431    def from_expression(
432        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
433    ) -> SetOperation:
434        assert isinstance(expression, exp.SetOperation)
435
436        left = Step.from_expression(expression.left, ctes)
437        # SELECT 1 UNION SELECT 2  <-- these subqueries don't have names
438        left.name = left.name or "left"
439        right = Step.from_expression(expression.right, ctes)
440        right.name = right.name or "right"
441        step = cls(
442            op=expression.__class__,
443            left=left.name,
444            right=right.name,
445            distinct=bool(expression.args.get("distinct")),
446        )
447
448        step.add_dependency(left)
449        step.add_dependency(right)
450
451        limit = expression.args.get("limit")
452
453        if limit:
454            step.limit = int(limit.text("expression"))
455
456        return step

Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. Note: the expression's tables and subqueries must be aliased for this method to work. For example, given the following expression:

SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a

the following DAG is produced (the expression IDs might differ per execution):

- Aggregate: x (4347984624)
    Context:
      Aggregations:
        - SUM(x.b)
      Group:
        - x.a
    Projections:
      - x.a
      - "x".""
    Dependencies:
    - Join: x (4347985296)
      Context:
        y:
        On: x.a = y.a
      Projections:
      Dependencies:
      - Scan: x (4347983136)
        Context:
          Source: x AS x
        Projections:
      - Scan: y (4343416624)
        Context:
          Source: y AS y
        Projections:
Arguments:
  • expression: the expression to build the DAG from.
  • ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:

A Step DAG corresponding to expression.

type_name: str
464    @property
465    def type_name(self) -> str:
466        return self.op.__name__