Edit on GitHub

sqlglot.planner

  1from __future__ import annotations
  2
  3import math
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.helper import name_sequence
  8from sqlglot.optimizer.eliminate_joins import join_condition
  9
 10
 11class Plan:
 12    def __init__(self, expression: exp.Expression) -> None:
 13        self.expression = expression.copy()
 14        self.root = Step.from_expression(self.expression)
 15        self._dag: t.Dict[Step, t.Set[Step]] = {}
 16
 17    @property
 18    def dag(self) -> t.Dict[Step, t.Set[Step]]:
 19        if not self._dag:
 20            dag: t.Dict[Step, t.Set[Step]] = {}
 21            nodes = {self.root}
 22
 23            while nodes:
 24                node = nodes.pop()
 25                dag[node] = set()
 26
 27                for dep in node.dependencies:
 28                    dag[node].add(dep)
 29                    nodes.add(dep)
 30
 31            self._dag = dag
 32
 33        return self._dag
 34
 35    @property
 36    def leaves(self) -> t.Iterator[Step]:
 37        return (node for node, deps in self.dag.items() if not deps)
 38
 39    def __repr__(self) -> str:
 40        return f"Plan\n----\n{repr(self.root)}"
 41
 42
 43class Step:
 44    @classmethod
 45    def from_expression(
 46        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
 47    ) -> Step:
 48        """
 49        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
 50        Note: the expression's tables and subqueries must be aliased for this method to work. For
 51        example, given the following expression:
 52
 53        SELECT
 54          x.a,
 55          SUM(x.b)
 56        FROM x AS x
 57        JOIN y AS y
 58          ON x.a = y.a
 59        GROUP BY x.a
 60
 61        the following DAG is produced (the expression IDs might differ per execution):
 62
 63        - Aggregate: x (4347984624)
 64            Context:
 65              Aggregations:
 66                - SUM(x.b)
 67              Group:
 68                - x.a
 69            Projections:
 70              - x.a
 71              - "x".""
 72            Dependencies:
 73            - Join: x (4347985296)
 74              Context:
 75                y:
 76                On: x.a = y.a
 77              Projections:
 78              Dependencies:
 79              - Scan: x (4347983136)
 80                Context:
 81                  Source: x AS x
 82                Projections:
 83              - Scan: y (4343416624)
 84                Context:
 85                  Source: y AS y
 86                Projections:
 87
 88        Args:
 89            expression: the expression to build the DAG from.
 90            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
 91
 92        Returns:
 93            A Step DAG corresponding to `expression`.
 94        """
 95        ctes = ctes or {}
 96        expression = expression.unnest()
 97        with_ = expression.args.get("with")
 98
 99        # CTEs break the mold of scope and introduce themselves to all in the context.
100        if with_:
101            ctes = ctes.copy()
102            for cte in with_.expressions:
103                step = Step.from_expression(cte.this, ctes)
104                step.name = cte.alias
105                ctes[step.name] = step  # type: ignore
106
107        from_ = expression.args.get("from")
108
109        if isinstance(expression, exp.Select) and from_:
110            step = Scan.from_expression(from_.this, ctes)
111        elif isinstance(expression, exp.SetOperation):
112            step = SetOperation.from_expression(expression, ctes)
113        else:
114            step = Scan()
115
116        joins = expression.args.get("joins")
117
118        if joins:
119            join = Join.from_joins(joins, ctes)
120            join.name = step.name
121            join.source_name = step.name
122            join.add_dependency(step)
123            step = join
124
125        projections = []  # final selects in this chain of steps representing a select
126        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
127        aggregations = {}
128        next_operand_name = name_sequence("_a_")
129
130        def extract_agg_operands(expression):
131            agg_funcs = tuple(expression.find_all(exp.AggFunc))
132            if agg_funcs:
133                aggregations[expression] = None
134
135            for agg in agg_funcs:
136                for operand in agg.unnest_operands():
137                    if isinstance(operand, exp.Column):
138                        continue
139                    if operand not in operands:
140                        operands[operand] = next_operand_name()
141
142                    operand.replace(exp.column(operands[operand], quoted=True))
143
144            return bool(agg_funcs)
145
146        def set_ops_and_aggs(step):
147            step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
148            step.aggregations = list(aggregations)
149
150        for e in expression.expressions:
151            if e.find(exp.AggFunc):
152                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
153                extract_agg_operands(e)
154            else:
155                projections.append(e)
156
157        where = expression.args.get("where")
158
159        if where:
160            step.condition = where.this
161
162        group = expression.args.get("group")
163
164        if group or aggregations:
165            aggregate = Aggregate()
166            aggregate.source = step.name
167            aggregate.name = step.name
168
169            having = expression.args.get("having")
170
171            if having:
172                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
173                    aggregate.condition = exp.column("_h", step.name, quoted=True)
174                else:
175                    aggregate.condition = having.this
176
177            set_ops_and_aggs(aggregate)
178
179            # give aggregates names and replace projections with references to them
180            aggregate.group = {
181                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
182            }
183
184            intermediate: t.Dict[str | exp.Expression, str] = {}
185            for k, v in aggregate.group.items():
186                intermediate[v] = k
187                if isinstance(v, exp.Column):
188                    intermediate[v.name] = k
189
190            for projection in projections:
191                for node in projection.walk():
192                    name = intermediate.get(node)
193                    if name:
194                        node.replace(exp.column(name, step.name))
195
196            if aggregate.condition:
197                for node in aggregate.condition.walk():
198                    name = intermediate.get(node) or intermediate.get(node.name)
199                    if name:
200                        node.replace(exp.column(name, step.name))
201
202            aggregate.add_dependency(step)
203            step = aggregate
204        else:
205            aggregate = None
206
207        order = expression.args.get("order")
208
209        if order:
210            if aggregate and isinstance(step, Aggregate):
211                for i, ordered in enumerate(order.expressions):
212                    if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
213                        ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
214
215                set_ops_and_aggs(aggregate)
216
217            sort = Sort()
218            sort.name = step.name
219            sort.key = order.expressions
220            sort.add_dependency(step)
221            step = sort
222
223        step.projections = projections
224
225        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
226            distinct = Aggregate()
227            distinct.source = step.name
228            distinct.name = step.name
229            distinct.group = {
230                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
231                for e in projections or expression.expressions
232            }
233            distinct.add_dependency(step)
234            step = distinct
235
236        limit = expression.args.get("limit")
237
238        if limit:
239            step.limit = int(limit.text("expression"))
240
241        return step
242
243    def __init__(self) -> None:
244        self.name: t.Optional[str] = None
245        self.dependencies: t.Set[Step] = set()
246        self.dependents: t.Set[Step] = set()
247        self.projections: t.Sequence[exp.Expression] = []
248        self.limit: float = math.inf
249        self.condition: t.Optional[exp.Expression] = None
250
251    def add_dependency(self, dependency: Step) -> None:
252        self.dependencies.add(dependency)
253        dependency.dependents.add(self)
254
255    def __repr__(self) -> str:
256        return self.to_s()
257
258    def to_s(self, level: int = 0) -> str:
259        indent = "  " * level
260        nested = f"{indent}    "
261
262        context = self._to_s(f"{nested}  ")
263
264        if context:
265            context = [f"{nested}Context:"] + context
266
267        lines = [
268            f"{indent}- {self.id}",
269            *context,
270            f"{nested}Projections:",
271        ]
272
273        for expression in self.projections:
274            lines.append(f"{nested}  - {expression.sql()}")
275
276        if self.condition:
277            lines.append(f"{nested}Condition: {self.condition.sql()}")
278
279        if self.limit is not math.inf:
280            lines.append(f"{nested}Limit: {self.limit}")
281
282        if self.dependencies:
283            lines.append(f"{nested}Dependencies:")
284            for dependency in self.dependencies:
285                lines.append("  " + dependency.to_s(level + 1))
286
287        return "\n".join(lines)
288
289    @property
290    def type_name(self) -> str:
291        return self.__class__.__name__
292
293    @property
294    def id(self) -> str:
295        name = self.name
296        name = f" {name}" if name else ""
297        return f"{self.type_name}:{name} ({id(self)})"
298
299    def _to_s(self, _indent: str) -> t.List[str]:
300        return []
301
302
class Scan(Step):
    """Step that reads from a source expression (e.g. a table)."""

    @classmethod
    def from_expression(
        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> Step:
        """Build the step for a FROM/JOIN target.

        A subquery is planned through `Step.from_expression` and the resulting
        step is named after the subquery's alias. Anything else becomes a
        `Scan` over `expression`, which additionally depends on the matching
        CTE step when the expression's name is a key of `ctes`.

        Args:
            expression: the source expression to scan.
            ctes: mapping of CTE names to their steps.

        Returns:
            The step producing the rows of `expression`.
        """
        step_name = expression.alias_or_name

        if isinstance(expression, exp.Subquery):
            inner = Step.from_expression(expression.this, ctes)
            inner.name = step_name
            return inner

        scan = Scan()
        scan.name = step_name
        scan.source = expression

        if ctes:
            if expression.name in ctes:
                scan.add_dependency(ctes[expression.name])

        return scan

    def __init__(self) -> None:
        super().__init__()
        # Expression being scanned; None when the scan has no source.
        self.source: t.Optional[exp.Expression] = None

    def _to_s(self, indent: str) -> t.List[str]:
        """Describe the scanned source, or '-static-' when there is none."""
        rendered = self.source.sql() if self.source else "-static-"  # type: ignore
        return [f"{indent}Source: {rendered}"]
332
class Join(Step):
    """Step that combines a source with one or more joined inputs."""

    @classmethod
    def from_joins(
        cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> Join:
        """Create a single Join step covering all of `joins`.

        For every join, the side together with the keys and condition returned
        by `join_condition` is stored under the join's alias (or name), and a
        step built from the joined expression is added as a dependency.

        Args:
            joins: the join expressions to plan.
            ctes: mapping of CTE names to their steps, forwarded to `Scan.from_expression`.

        Returns:
            The populated Join step.
        """
        join_step = Join()

        for join_expr in joins:
            source_key, join_key, condition = join_condition(join_expr)
            join_step.joins[join_expr.alias_or_name] = dict(
                side=join_expr.side,  # type: ignore
                join_key=join_key,
                source_key=source_key,
                condition=condition,
            )

            joined_input = Scan.from_expression(join_expr.this, ctes)
            join_step.add_dependency(joined_input)

        return join_step

    def __init__(self) -> None:
        super().__init__()
        # Name of the step the joins are applied to.
        self.source_name: t.Optional[str] = None
        # Per-join details keyed by the join's alias or name.
        self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}

    def _to_s(self, indent: str) -> t.List[str]:
        """List the source followed by each join's side, keys and condition."""
        out = [f"{indent}Source: {self.source_name or self.name}"]

        for join_name, details in self.joins.items():
            out.append(f"{indent}{join_name}: {details['side'] or 'INNER'}")

            keys = t.cast(list, details.get("join_key") or [])
            rendered_keys = ", ".join(map(str, keys))
            if rendered_keys:
                out.append(f"{indent}Key: {rendered_keys}")

            if details.get("condition"):
                out.append(f"{indent}On: {details['condition'].sql()}")  # type: ignore

        return out
368
369
class Aggregate(Step):
    """Step holding aggregations, their operands and the grouping keys."""

    def __init__(self) -> None:
        super().__init__()
        self.aggregations: t.List[exp.Expression] = []
        self.operands: t.Tuple[exp.Expression, ...] = ()
        # Grouping expressions keyed by name.
        self.group: t.Dict[str, exp.Expression] = {}
        # Name of the step this aggregate reads from.
        self.source: t.Optional[str] = None

    def _to_s(self, indent: str) -> t.List[str]:
        """Render the aggregations, then the group, having and operand sections that are set."""
        bullet = f"{indent}  - "

        out = [f"{indent}Aggregations:"]
        out.extend(f"{bullet}{aggregation.sql()}" for aggregation in self.aggregations)

        if self.group:
            out.append(f"{indent}Group:")
            out.extend(f"{bullet}{key.sql()}" for key in self.group.values())

        if self.condition:
            out.append(f"{indent}Having:")
            out.append(f"{bullet}{self.condition.sql()}")

        if self.operands:
            out.append(f"{indent}Operands:")
            out.extend(f"{bullet}{operand.sql()}" for operand in self.operands)

        return out
397
398
class Sort(Step):
    """Step that orders the rows of its dependency by `key`."""

    def __init__(self) -> None:
        super().__init__()
        # Ordering expressions; stays None until a key is assigned.
        self.key: t.Optional[t.Sequence[exp.Expression]] = None

    def _to_s(self, indent: str) -> t.List[str]:
        """Render the sort key, one expression per line."""
        lines = [f"{indent}Key:"]

        # `key` defaults to None, so fall back to an empty sequence rather than
        # failing while printing a Sort whose key has not been set.
        for expression in self.key or ():
            lines.append(f"{indent}  - {expression.sql()}")

        return lines
411
412
class SetOperation(Step):
    """Step combining the results of two dependencies with a set operator."""

    def __init__(
        self,
        op: t.Type[exp.Expression],
        left: str | None,
        right: str | None,
        distinct: bool = False,
    ) -> None:
        super().__init__()
        # Expression class of the operator, plus the names of both input steps.
        self.op, self.left, self.right = op, left, right
        self.distinct = distinct

    @classmethod
    def from_expression(
        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> SetOperation:
        """Plan both sides of a set operation and tie them together.

        Args:
            expression: the set operation expression.
            ctes: mapping of CTE names to their steps, forwarded to both sides.

        Returns:
            A SetOperation step depending on the left and right steps.
        """
        assert isinstance(expression, exp.SetOperation)

        # SELECT 1 UNION SELECT 2  <-- these subqueries don't have names
        lhs = Step.from_expression(expression.left, ctes)
        if not lhs.name:
            lhs.name = "left"

        rhs = Step.from_expression(expression.right, ctes)
        if not rhs.name:
            rhs.name = "right"

        is_distinct = bool(expression.args.get("distinct"))
        set_op = cls(
            op=type(expression),
            left=lhs.name,
            right=rhs.name,
            distinct=is_distinct,
        )

        for side in (lhs, rhs):
            set_op.add_dependency(side)

        limit = expression.args.get("limit")
        if limit:
            set_op.limit = int(limit.text("expression"))

        return set_op

    def _to_s(self, indent: str) -> t.List[str]:
        """Emit a 'Distinct' line only when the operation is distinct."""
        if not self.distinct:
            return []
        return [f"{indent}Distinct: {self.distinct}"]

    @property
    def type_name(self) -> str:
        """Label the step with the operator's class name instead of 'SetOperation'."""
        return self.op.__name__
class Plan:
12class Plan:
13    def __init__(self, expression: exp.Expression) -> None:
14        self.expression = expression.copy()
15        self.root = Step.from_expression(self.expression)
16        self._dag: t.Dict[Step, t.Set[Step]] = {}
17
18    @property
19    def dag(self) -> t.Dict[Step, t.Set[Step]]:
20        if not self._dag:
21            dag: t.Dict[Step, t.Set[Step]] = {}
22            nodes = {self.root}
23
24            while nodes:
25                node = nodes.pop()
26                dag[node] = set()
27
28                for dep in node.dependencies:
29                    dag[node].add(dep)
30                    nodes.add(dep)
31
32            self._dag = dag
33
34        return self._dag
35
36    @property
37    def leaves(self) -> t.Iterator[Step]:
38        return (node for node, deps in self.dag.items() if not deps)
39
40    def __repr__(self) -> str:
41        return f"Plan\n----\n{repr(self.root)}"
Plan(expression: sqlglot.expressions.Expression)
13    def __init__(self, expression: exp.Expression) -> None:
14        self.expression = expression.copy()
15        self.root = Step.from_expression(self.expression)
16        self._dag: t.Dict[Step, t.Set[Step]] = {}
expression
root
dag: Dict[Step, Set[Step]]
18    @property
19    def dag(self) -> t.Dict[Step, t.Set[Step]]:
20        if not self._dag:
21            dag: t.Dict[Step, t.Set[Step]] = {}
22            nodes = {self.root}
23
24            while nodes:
25                node = nodes.pop()
26                dag[node] = set()
27
28                for dep in node.dependencies:
29                    dag[node].add(dep)
30                    nodes.add(dep)
31
32            self._dag = dag
33
34        return self._dag
leaves: Iterator[Step]
36    @property
37    def leaves(self) -> t.Iterator[Step]:
38        return (node for node, deps in self.dag.items() if not deps)
class Step:
 44class Step:
 45    @classmethod
 46    def from_expression(
 47        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
 48    ) -> Step:
 49        """
 50        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
 51        Note: the expression's tables and subqueries must be aliased for this method to work. For
 52        example, given the following expression:
 53
 54        SELECT
 55          x.a,
 56          SUM(x.b)
 57        FROM x AS x
 58        JOIN y AS y
 59          ON x.a = y.a
 60        GROUP BY x.a
 61
 62        the following DAG is produced (the expression IDs might differ per execution):
 63
 64        - Aggregate: x (4347984624)
 65            Context:
 66              Aggregations:
 67                - SUM(x.b)
 68              Group:
 69                - x.a
 70            Projections:
 71              - x.a
 72              - "x".""
 73            Dependencies:
 74            - Join: x (4347985296)
 75              Context:
 76                y:
 77                On: x.a = y.a
 78              Projections:
 79              Dependencies:
 80              - Scan: x (4347983136)
 81                Context:
 82                  Source: x AS x
 83                Projections:
 84              - Scan: y (4343416624)
 85                Context:
 86                  Source: y AS y
 87                Projections:
 88
 89        Args:
 90            expression: the expression to build the DAG from.
 91            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
 92
 93        Returns:
 94            A Step DAG corresponding to `expression`.
 95        """
 96        ctes = ctes or {}
 97        expression = expression.unnest()
 98        with_ = expression.args.get("with")
 99
100        # CTEs break the mold of scope and introduce themselves to all in the context.
101        if with_:
102            ctes = ctes.copy()
103            for cte in with_.expressions:
104                step = Step.from_expression(cte.this, ctes)
105                step.name = cte.alias
106                ctes[step.name] = step  # type: ignore
107
108        from_ = expression.args.get("from")
109
110        if isinstance(expression, exp.Select) and from_:
111            step = Scan.from_expression(from_.this, ctes)
112        elif isinstance(expression, exp.SetOperation):
113            step = SetOperation.from_expression(expression, ctes)
114        else:
115            step = Scan()
116
117        joins = expression.args.get("joins")
118
119        if joins:
120            join = Join.from_joins(joins, ctes)
121            join.name = step.name
122            join.source_name = step.name
123            join.add_dependency(step)
124            step = join
125
126        projections = []  # final selects in this chain of steps representing a select
127        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
128        aggregations = {}
129        next_operand_name = name_sequence("_a_")
130
131        def extract_agg_operands(expression):
132            agg_funcs = tuple(expression.find_all(exp.AggFunc))
133            if agg_funcs:
134                aggregations[expression] = None
135
136            for agg in agg_funcs:
137                for operand in agg.unnest_operands():
138                    if isinstance(operand, exp.Column):
139                        continue
140                    if operand not in operands:
141                        operands[operand] = next_operand_name()
142
143                    operand.replace(exp.column(operands[operand], quoted=True))
144
145            return bool(agg_funcs)
146
147        def set_ops_and_aggs(step):
148            step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
149            step.aggregations = list(aggregations)
150
151        for e in expression.expressions:
152            if e.find(exp.AggFunc):
153                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
154                extract_agg_operands(e)
155            else:
156                projections.append(e)
157
158        where = expression.args.get("where")
159
160        if where:
161            step.condition = where.this
162
163        group = expression.args.get("group")
164
165        if group or aggregations:
166            aggregate = Aggregate()
167            aggregate.source = step.name
168            aggregate.name = step.name
169
170            having = expression.args.get("having")
171
172            if having:
173                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
174                    aggregate.condition = exp.column("_h", step.name, quoted=True)
175                else:
176                    aggregate.condition = having.this
177
178            set_ops_and_aggs(aggregate)
179
180            # give aggregates names and replace projections with references to them
181            aggregate.group = {
182                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
183            }
184
185            intermediate: t.Dict[str | exp.Expression, str] = {}
186            for k, v in aggregate.group.items():
187                intermediate[v] = k
188                if isinstance(v, exp.Column):
189                    intermediate[v.name] = k
190
191            for projection in projections:
192                for node in projection.walk():
193                    name = intermediate.get(node)
194                    if name:
195                        node.replace(exp.column(name, step.name))
196
197            if aggregate.condition:
198                for node in aggregate.condition.walk():
199                    name = intermediate.get(node) or intermediate.get(node.name)
200                    if name:
201                        node.replace(exp.column(name, step.name))
202
203            aggregate.add_dependency(step)
204            step = aggregate
205        else:
206            aggregate = None
207
208        order = expression.args.get("order")
209
210        if order:
211            if aggregate and isinstance(step, Aggregate):
212                for i, ordered in enumerate(order.expressions):
213                    if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
214                        ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
215
216                set_ops_and_aggs(aggregate)
217
218            sort = Sort()
219            sort.name = step.name
220            sort.key = order.expressions
221            sort.add_dependency(step)
222            step = sort
223
224        step.projections = projections
225
226        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
227            distinct = Aggregate()
228            distinct.source = step.name
229            distinct.name = step.name
230            distinct.group = {
231                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
232                for e in projections or expression.expressions
233            }
234            distinct.add_dependency(step)
235            step = distinct
236
237        limit = expression.args.get("limit")
238
239        if limit:
240            step.limit = int(limit.text("expression"))
241
242        return step
243
244    def __init__(self) -> None:
245        self.name: t.Optional[str] = None
246        self.dependencies: t.Set[Step] = set()
247        self.dependents: t.Set[Step] = set()
248        self.projections: t.Sequence[exp.Expression] = []
249        self.limit: float = math.inf
250        self.condition: t.Optional[exp.Expression] = None
251
252    def add_dependency(self, dependency: Step) -> None:
253        self.dependencies.add(dependency)
254        dependency.dependents.add(self)
255
256    def __repr__(self) -> str:
257        return self.to_s()
258
259    def to_s(self, level: int = 0) -> str:
260        indent = "  " * level
261        nested = f"{indent}    "
262
263        context = self._to_s(f"{nested}  ")
264
265        if context:
266            context = [f"{nested}Context:"] + context
267
268        lines = [
269            f"{indent}- {self.id}",
270            *context,
271            f"{nested}Projections:",
272        ]
273
274        for expression in self.projections:
275            lines.append(f"{nested}  - {expression.sql()}")
276
277        if self.condition:
278            lines.append(f"{nested}Condition: {self.condition.sql()}")
279
280        if self.limit is not math.inf:
281            lines.append(f"{nested}Limit: {self.limit}")
282
283        if self.dependencies:
284            lines.append(f"{nested}Dependencies:")
285            for dependency in self.dependencies:
286                lines.append("  " + dependency.to_s(level + 1))
287
288        return "\n".join(lines)
289
290    @property
291    def type_name(self) -> str:
292        return self.__class__.__name__
293
294    @property
295    def id(self) -> str:
296        name = self.name
297        name = f" {name}" if name else ""
298        return f"{self.type_name}:{name} ({id(self)})"
299
300    def _to_s(self, _indent: str) -> t.List[str]:
301        return []
@classmethod
def from_expression( cls, expression: sqlglot.expressions.Expression, ctes: Optional[Dict[str, Step]] = None) -> Step:
 45    @classmethod
 46    def from_expression(
 47        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
 48    ) -> Step:
 49        """
 50        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
 51        Note: the expression's tables and subqueries must be aliased for this method to work. For
 52        example, given the following expression:
 53
 54        SELECT
 55          x.a,
 56          SUM(x.b)
 57        FROM x AS x
 58        JOIN y AS y
 59          ON x.a = y.a
 60        GROUP BY x.a
 61
 62        the following DAG is produced (the expression IDs might differ per execution):
 63
 64        - Aggregate: x (4347984624)
 65            Context:
 66              Aggregations:
 67                - SUM(x.b)
 68              Group:
 69                - x.a
 70            Projections:
 71              - x.a
 72              - "x".""
 73            Dependencies:
 74            - Join: x (4347985296)
 75              Context:
 76                y:
 77                On: x.a = y.a
 78              Projections:
 79              Dependencies:
 80              - Scan: x (4347983136)
 81                Context:
 82                  Source: x AS x
 83                Projections:
 84              - Scan: y (4343416624)
 85                Context:
 86                  Source: y AS y
 87                Projections:
 88
 89        Args:
 90            expression: the expression to build the DAG from.
 91            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
 92
 93        Returns:
 94            A Step DAG corresponding to `expression`.
 95        """
 96        ctes = ctes or {}
 97        expression = expression.unnest()
 98        with_ = expression.args.get("with")
 99
100        # CTEs break the mold of scope and introduce themselves to all in the context.
101        if with_:
102            ctes = ctes.copy()
103            for cte in with_.expressions:
104                step = Step.from_expression(cte.this, ctes)
105                step.name = cte.alias
106                ctes[step.name] = step  # type: ignore
107
108        from_ = expression.args.get("from")
109
110        if isinstance(expression, exp.Select) and from_:
111            step = Scan.from_expression(from_.this, ctes)
112        elif isinstance(expression, exp.SetOperation):
113            step = SetOperation.from_expression(expression, ctes)
114        else:
115            step = Scan()
116
117        joins = expression.args.get("joins")
118
119        if joins:
120            join = Join.from_joins(joins, ctes)
121            join.name = step.name
122            join.source_name = step.name
123            join.add_dependency(step)
124            step = join
125
126        projections = []  # final selects in this chain of steps representing a select
127        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
128        aggregations = {}
129        next_operand_name = name_sequence("_a_")
130
131        def extract_agg_operands(expression):
132            agg_funcs = tuple(expression.find_all(exp.AggFunc))
133            if agg_funcs:
134                aggregations[expression] = None
135
136            for agg in agg_funcs:
137                for operand in agg.unnest_operands():
138                    if isinstance(operand, exp.Column):
139                        continue
140                    if operand not in operands:
141                        operands[operand] = next_operand_name()
142
143                    operand.replace(exp.column(operands[operand], quoted=True))
144
145            return bool(agg_funcs)
146
147        def set_ops_and_aggs(step):
148            step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
149            step.aggregations = list(aggregations)
150
151        for e in expression.expressions:
152            if e.find(exp.AggFunc):
153                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
154                extract_agg_operands(e)
155            else:
156                projections.append(e)
157
158        where = expression.args.get("where")
159
160        if where:
161            step.condition = where.this
162
163        group = expression.args.get("group")
164
165        if group or aggregations:
166            aggregate = Aggregate()
167            aggregate.source = step.name
168            aggregate.name = step.name
169
170            having = expression.args.get("having")
171
172            if having:
173                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
174                    aggregate.condition = exp.column("_h", step.name, quoted=True)
175                else:
176                    aggregate.condition = having.this
177
178            set_ops_and_aggs(aggregate)
179
180            # give aggregates names and replace projections with references to them
181            aggregate.group = {
182                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
183            }
184
185            intermediate: t.Dict[str | exp.Expression, str] = {}
186            for k, v in aggregate.group.items():
187                intermediate[v] = k
188                if isinstance(v, exp.Column):
189                    intermediate[v.name] = k
190
191            for projection in projections:
192                for node in projection.walk():
193                    name = intermediate.get(node)
194                    if name:
195                        node.replace(exp.column(name, step.name))
196
197            if aggregate.condition:
198                for node in aggregate.condition.walk():
199                    name = intermediate.get(node) or intermediate.get(node.name)
200                    if name:
201                        node.replace(exp.column(name, step.name))
202
203            aggregate.add_dependency(step)
204            step = aggregate
205        else:
206            aggregate = None
207
208        order = expression.args.get("order")
209
210        if order:
211            if aggregate and isinstance(step, Aggregate):
212                for i, ordered in enumerate(order.expressions):
213                    if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
214                        ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
215
216                set_ops_and_aggs(aggregate)
217
218            sort = Sort()
219            sort.name = step.name
220            sort.key = order.expressions
221            sort.add_dependency(step)
222            step = sort
223
224        step.projections = projections
225
226        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
227            distinct = Aggregate()
228            distinct.source = step.name
229            distinct.name = step.name
230            distinct.group = {
231                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
232                for e in projections or expression.expressions
233            }
234            distinct.add_dependency(step)
235            step = distinct
236
237        limit = expression.args.get("limit")
238
239        if limit:
240            step.limit = int(limit.text("expression"))
241
242        return step

Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. Note: the expression's tables and subqueries must be aliased for this method to work. For example, given the following expression:

SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a

the following DAG is produced (the expression IDs might differ per execution):

    - Aggregate: x (4347984624)
        Context:
          Aggregations:
            - SUM(x.b)
          Group:
            - x.a
        Projections:
          - x.a
          - "x".""
        Dependencies:
        - Join: x (4347985296)
          Context:
            y:
            On: x.a = y.a
          Projections:
          Dependencies:
          - Scan: x (4347983136)
            Context:
              Source: x AS x
            Projections:
          - Scan: y (4343416624)
            Context:
              Source: y AS y
            Projections:

Arguments:
  • expression: the expression to build the DAG from.
  • ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:

A Step DAG corresponding to expression.

name: Optional[str]
dependencies: Set[Step]
dependents: Set[Step]
projections: Sequence[sqlglot.expressions.Expression]
limit: float
condition: Optional[sqlglot.expressions.Expression]
def add_dependency(self, dependency: Step) -> None:
252    def add_dependency(self, dependency: Step) -> None:
253        self.dependencies.add(dependency)
254        dependency.dependents.add(self)
def to_s(self, level: int = 0) -> str:
259    def to_s(self, level: int = 0) -> str:
260        indent = "  " * level
261        nested = f"{indent}    "
262
263        context = self._to_s(f"{nested}  ")
264
265        if context:
266            context = [f"{nested}Context:"] + context
267
268        lines = [
269            f"{indent}- {self.id}",
270            *context,
271            f"{nested}Projections:",
272        ]
273
274        for expression in self.projections:
275            lines.append(f"{nested}  - {expression.sql()}")
276
277        if self.condition:
278            lines.append(f"{nested}Condition: {self.condition.sql()}")
279
280        if self.limit is not math.inf:
281            lines.append(f"{nested}Limit: {self.limit}")
282
283        if self.dependencies:
284            lines.append(f"{nested}Dependencies:")
285            for dependency in self.dependencies:
286                lines.append("  " + dependency.to_s(level + 1))
287
288        return "\n".join(lines)
type_name: str
290    @property
291    def type_name(self) -> str:
292        return self.__class__.__name__
id: str
294    @property
295    def id(self) -> str:
296        name = self.name
297        name = f" {name}" if name else ""
298        return f"{self.type_name}:{name} ({id(self)})"
class Scan(Step):
class Scan(Step):
    """Step that carries a single source expression to scan."""

    @classmethod
    def from_expression(
        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> Step:
        """
        Builds the Step for a single source expression.

        Args:
            expression: the source expression (e.g. a table or a subquery).
            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.

        Returns:
            For a subquery, the Step built from its inner expression, renamed to the
            subquery's alias. Otherwise a Scan over `expression` that depends on the
            matching CTE Step when the expression's name is a key of `ctes`.
        """
        source_alias = expression.alias_or_name

        if isinstance(expression, exp.Subquery):
            # A subquery is planned as its own DAG and simply takes over the alias.
            subquery_step = Step.from_expression(expression.this, ctes)
            subquery_step.name = source_alias
            return subquery_step

        scan = Scan()
        scan.name = source_alias
        scan.source = expression

        if not ctes:
            return scan

        table_name = expression.name
        if table_name in ctes:
            scan.add_dependency(ctes[table_name])

        return scan

    def __init__(self) -> None:
        super().__init__()
        # The expression this step reads from; None until `from_expression` sets it.
        self.source: t.Optional[exp.Expression] = None

    def _to_s(self, indent: str) -> t.List[str]:
        """Returns the single "Source: ..." context line, using '-static-' when no source is set."""
        source_sql = self.source.sql() if self.source else "-static-"
        return [f"{indent}Source: {source_sql}"]
@classmethod
def from_expression( cls, expression: sqlglot.expressions.Expression, ctes: Optional[Dict[str, Step]] = None) -> Step:
305    @classmethod
306    def from_expression(
307        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
308    ) -> Step:
309        table = expression
310        alias_ = expression.alias_or_name
311
312        if isinstance(expression, exp.Subquery):
313            table = expression.this
314            step = Step.from_expression(table, ctes)
315            step.name = alias_
316            return step
317
318        step = Scan()
319        step.name = alias_
320        step.source = expression
321        if ctes and table.name in ctes:
322            step.add_dependency(ctes[table.name])
323
324        return step

Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. Note: the expression's tables and subqueries must be aliased for this method to work. For example, given the following expression:

SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a

the following DAG is produced (the expression IDs might differ per execution):

    - Aggregate: x (4347984624)
        Context:
          Aggregations:
            - SUM(x.b)
          Group:
            - x.a
        Projections:
          - x.a
          - "x".""
        Dependencies:
        - Join: x (4347985296)
          Context:
            y:
            On: x.a = y.a
          Projections:
          Dependencies:
          - Scan: x (4347983136)
            Context:
              Source: x AS x
            Projections:
          - Scan: y (4343416624)
            Context:
              Source: y AS y
            Projections:

Arguments:
  • expression: the expression to build the DAG from.
  • ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:

A Step DAG corresponding to expression.

source: Optional[sqlglot.expressions.Expression]
class Join(Step):
class Join(Step):
    """Step that records, per joined source, the join side, keys and condition."""

    @classmethod
    def from_joins(
        cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> Join:
        """
        Builds one Join step covering every join in `joins`.

        Args:
            joins: the join expressions to combine.
            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.

        Returns:
            A Join step with one `joins` entry per join (keyed by its alias or name)
            and one dependency per joined source.
        """
        join_step = Join()

        for join_expr in joins:
            source_key, join_key, condition = join_condition(join_expr)
            spec = {
                "side": join_expr.side,  # type: ignore
                "join_key": join_key,
                "source_key": source_key,
                "condition": condition,
            }
            join_step.joins[join_expr.alias_or_name] = spec
            join_step.add_dependency(Scan.from_expression(join_expr.this, ctes))

        return join_step

    def __init__(self) -> None:
        super().__init__()
        # Label shown on the "Source:" line; `_to_s` falls back to `name` when unset.
        self.source_name: t.Optional[str] = None
        # One entry per joined source with "side", "join_key", "source_key", "condition".
        self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}

    def _to_s(self, indent: str) -> t.List[str]:
        """Returns the context lines: the source, then side / key / condition per join."""
        rendered = [f"{indent}Source: {self.source_name or self.name}"]

        for alias_, spec in self.joins.items():
            rendered.append(f"{indent}{alias_}: {spec['side'] or 'INNER'}")

            key_text = ", ".join(map(str, t.cast(list, spec.get("join_key") or [])))
            if key_text:
                rendered.append(f"{indent}Key: {key_text}")

            on_clause = spec.get("condition")
            if on_clause:
                rendered.append(f"{indent}On: {on_clause.sql()}")  # type: ignore

        return rendered
@classmethod
def from_joins( cls, joins: Iterable[sqlglot.expressions.Join], ctes: Optional[Dict[str, Step]] = None) -> Join:
335    @classmethod
336    def from_joins(
337        cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
338    ) -> Join:
339        step = Join()
340
341        for join in joins:
342            source_key, join_key, condition = join_condition(join)
343            step.joins[join.alias_or_name] = {
344                "side": join.side,  # type: ignore
345                "join_key": join_key,
346                "source_key": source_key,
347                "condition": condition,
348            }
349
350            step.add_dependency(Scan.from_expression(join.this, ctes))
351
352        return step
source_name: Optional[str]
joins: Dict[str, Dict[str, Union[List[str], sqlglot.expressions.Expression]]]
class Aggregate(Step):
class Aggregate(Step):
    """Step holding aggregations, their operands and the grouping expressions."""

    def __init__(self) -> None:
        super().__init__()
        self.aggregations: t.List[exp.Expression] = []
        self.operands: t.Tuple[exp.Expression, ...] = ()
        # Grouping expressions keyed by name.
        self.group: t.Dict[str, exp.Expression] = {}
        # Name of the step this aggregate reads from.
        self.source: t.Optional[str] = None

    def _to_s(self, indent: str) -> t.List[str]:
        """Returns the context lines: aggregations, then group / having / operands when present."""
        bullet = f"{indent}  - "

        lines = [f"{indent}Aggregations:"]
        lines.extend(f"{bullet}{agg.sql()}" for agg in self.aggregations)

        if self.group:
            lines.append(f"{indent}Group:")
            lines.extend(f"{bullet}{grouped.sql()}" for grouped in self.group.values())

        if self.condition:
            lines += [f"{indent}Having:", f"{bullet}{self.condition.sql()}"]

        if self.operands:
            lines.append(f"{indent}Operands:")
            lines.extend(f"{bullet}{operand.sql()}" for operand in self.operands)

        return lines
aggregations: List[sqlglot.expressions.Expression]
operands: Tuple[sqlglot.expressions.Expression, ...]
group: Dict[str, sqlglot.expressions.Expression]
source: Optional[str]
class Sort(Step):
class Sort(Step):
    """Step that carries the ordering expressions in `key`."""

    def __init__(self) -> None:
        super().__init__()
        # Ordering expressions; stays None until a caller assigns them.
        self.key = None

    def _to_s(self, indent: str) -> t.List[str]:
        """
        Returns the context lines: a "Key:" header followed by one line per ordering expression.

        Args:
            indent: the prefix to put in front of every line.
        """
        lines = [f"{indent}Key:"]

        # `key` defaults to None; fall back to an empty sequence so that rendering
        # a Sort whose key has not been assigned does not raise TypeError.
        for expression in self.key or ():  # type: ignore
            lines.append(f"{indent}  - {expression.sql()}")

        return lines
key
class SetOperation(Step):
class SetOperation(Step):
    """Step that combines the results of two named dependency steps with a set operator."""

    def __init__(
        self,
        op: t.Type[exp.Expression],
        left: str | None,
        right: str | None,
        distinct: bool = False,
    ) -> None:
        """
        Args:
            op: the expression class of the set operation.
            left: the name of the left-hand step.
            right: the name of the right-hand step.
            distinct: the operation's distinct flag.
        """
        super().__init__()
        self.op = op
        self.left = left
        self.right = right
        self.distinct = distinct

    @classmethod
    def from_expression(
        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> SetOperation:
        """
        Builds a SetOperation step from a set operation expression.

        Args:
            expression: the set operation expression to build the step from.
            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.

        Returns:
            A step that depends on the steps built for both operands and carries the
            expression's limit when it has one.
        """
        assert isinstance(expression, exp.SetOperation)

        # SELECT 1 UNION SELECT 2  <-- these subqueries don't have names
        lhs = Step.from_expression(expression.left, ctes)
        if not lhs.name:
            lhs.name = "left"

        rhs = Step.from_expression(expression.right, ctes)
        if not rhs.name:
            rhs.name = "right"

        is_distinct = bool(expression.args.get("distinct"))
        set_op = cls(
            op=type(expression),
            left=lhs.name,
            right=rhs.name,
            distinct=is_distinct,
        )

        for operand in (lhs, rhs):
            set_op.add_dependency(operand)

        limit_expr = expression.args.get("limit")
        if limit_expr:
            set_op.limit = int(limit_expr.text("expression"))

        return set_op

    def _to_s(self, indent: str) -> t.List[str]:
        """Returns a single "Distinct: ..." context line when the flag is set, else no lines."""
        return [f"{indent}Distinct: {self.distinct}"] if self.distinct else []

    @property
    def type_name(self) -> str:
        """The name of the set operation's expression class, instead of this step's class name."""
        op_class = self.op
        return op_class.__name__
SetOperation( op: Type[sqlglot.expressions.Expression], left: str | None, right: str | None, distinct: bool = False)
415    def __init__(
416        self,
417        op: t.Type[exp.Expression],
418        left: str | None,
419        right: str | None,
420        distinct: bool = False,
421    ) -> None:
422        super().__init__()
423        self.op = op
424        self.left = left
425        self.right = right
426        self.distinct = distinct
op
left
right
distinct
@classmethod
def from_expression( cls, expression: sqlglot.expressions.Expression, ctes: Optional[Dict[str, Step]] = None) -> SetOperation:
428    @classmethod
429    def from_expression(
430        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
431    ) -> SetOperation:
432        assert isinstance(expression, exp.SetOperation)
433
434        left = Step.from_expression(expression.left, ctes)
435        # SELECT 1 UNION SELECT 2  <-- these subqueries don't have names
436        left.name = left.name or "left"
437        right = Step.from_expression(expression.right, ctes)
438        right.name = right.name or "right"
439        step = cls(
440            op=expression.__class__,
441            left=left.name,
442            right=right.name,
443            distinct=bool(expression.args.get("distinct")),
444        )
445
446        step.add_dependency(left)
447        step.add_dependency(right)
448
449        limit = expression.args.get("limit")
450
451        if limit:
452            step.limit = int(limit.text("expression"))
453
454        return step

Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. Note: the expression's tables and subqueries must be aliased for this method to work. For example, given the following expression:

SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a

the following DAG is produced (the expression IDs might differ per execution):

    - Aggregate: x (4347984624)
        Context:
          Aggregations:
            - SUM(x.b)
          Group:
            - x.a
        Projections:
          - x.a
          - "x".""
        Dependencies:
        - Join: x (4347985296)
          Context:
            y:
            On: x.a = y.a
          Projections:
          Dependencies:
          - Scan: x (4347983136)
            Context:
              Source: x AS x
            Projections:
          - Scan: y (4343416624)
            Context:
              Source: y AS y
            Projections:

Arguments:
  • expression: the expression to build the DAG from.
  • ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:

A Step DAG corresponding to expression.

type_name: str
462    @property
463    def type_name(self) -> str:
464        return self.op.__name__