Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop())
 71        else:
 72            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 73
 74        window = exp.alias_(window, row_number)
 75        expression.select(window, copy=False)
 76
 77        return (
 78            exp.select(*outer_selects, copy=False)
 79            .from_(expression.subquery("_t", copy=False), copy=False)
 80            .where(exp.column(row_number).eq(1), copy=False)
 81        )
 82
 83    return expression
 84
 85
 86def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 87    """
 88    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 89
 90    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 91    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 92
 93    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 94    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 95    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 96    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
 97    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
 98    corresponding expression to avoid creating invalid column references.
 99    """
100    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
101        taken = set(expression.named_selects)
102        for select in expression.selects:
103            if not select.alias_or_name:
104                alias = find_new_name(taken, "_c")
105                select.replace(exp.alias_(select, alias))
106                taken.add(alias)
107
108        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
109        qualify_filters = expression.args["qualify"].pop().this
110        expression_by_alias = {
111            select.alias: select.this
112            for select in expression.selects
113            if isinstance(select, exp.Alias)
114        }
115
116        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
117        for select_candidate in qualify_filters.find_all(select_candidates):
118            if isinstance(select_candidate, exp.Window):
119                if expression_by_alias:
120                    for column in select_candidate.find_all(exp.Column):
121                        expr = expression_by_alias.get(column.name)
122                        if expr:
123                            column.replace(expr)
124
125                alias = find_new_name(expression.named_selects, "_w")
126                expression.select(exp.alias_(select_candidate, alias), copy=False)
127                column = exp.column(alias)
128
129                if isinstance(select_candidate.parent, exp.Qualify):
130                    qualify_filters = column
131                else:
132                    select_candidate.replace(column)
133            elif select_candidate.name not in expression.named_selects:
134                expression.select(select_candidate.copy(), copy=False)
135
136        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
137            qualify_filters, copy=False
138        )
139
140    return expression
141
142
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform removes the precision from parameterized types in
    expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [
            param for param in data_type.expressions if not isinstance(param, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept)

    return expression
154
155
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Only unnests that sit directly under FROM or a JOIN contribute aliases.
    unnest_aliases = {
        unnest.alias
        for unnest in find_all_in_scope(expression, exp.Unnest)
        if isinstance(unnest.parent, (exp.From, exp.Join))
    }
    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression
174
175
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of a SELECT whose target is an UNNEST is removed, and one LATERAL VIEW is
    appended per (unnested expression, alias column) pair. POSEXPLODE is used when the
    UNNEST has an "offset" arg, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating over it skips the element that follows each
        # removal, which would leave consecutive UNNEST joins unconverted.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # Without an alias the zip below is empty, so no lateral view is produced.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
199
200
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the value the generated position series starts at. It is also used to
            adjust the array-size bounds that the series is compared against.

    Returns:
        A transform that rewrites the [POS]EXPLODE projections of a SELECT and returns it.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if not isinstance(expression, exp.Select):
            return expression

        from sqlglot.optimizer.scope import Scope

        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        def new_name(names: t.Set[str], name: str) -> str:
            # Reserve the generated name right away so that it isn't handed out twice.
            name = find_new_name(names, name)
            names.add(name)
            return name

        def positions_match(series_table, pos, source) -> exp.Expression:
            # Builds a fresh "<series position> = <unnest position>" comparison on each call.
            return exp.column(series_alias, table=series_table).eq(exp.column(pos, table=source))

        arrays: t.List[exp.Condition] = []
        series_alias = new_name(taken_select_names, "pos")
        series = exp.alias_(
            exp.Unnest(
                expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
            ),
            new_name(taken_source_names, "_u"),
            table=[series_alias],
        )

        # we use list here because expression.selects is mutated inside the loop
        for select in list(expression.selects):
            explode = select.find(exp.Explode)
            if not explode:
                continue

            pos_alias = ""
            explode_alias = ""

            if isinstance(select, exp.Alias):
                explode_alias = select.args["alias"]
                alias = select
            elif isinstance(select, exp.Aliases):
                pos_alias = select.aliases[0]
                explode_alias = select.aliases[1]
                alias = select.replace(exp.alias_(select.this, "", copy=False))
            else:
                alias = select.replace(exp.alias_(select, ""))
                explode = alias.find(exp.Explode)
                assert explode

            is_posexplode = isinstance(explode, exp.Posexplode)
            explode_arg = explode.this

            if isinstance(explode, exp.ExplodeOuter):
                bracket = explode_arg[0]
                bracket.set("safe", True)
                bracket.set("offset", True)
                explode_arg = exp.func(
                    "IF",
                    exp.func("ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())).eq(0),
                    exp.array(bracket, copy=False),
                    explode_arg,
                )

            # This ensures that we won't use [POS]EXPLODE's argument as a new selection
            if isinstance(explode_arg, exp.Column):
                taken_select_names.add(explode_arg.output_name)

            unnest_source_alias = new_name(taken_source_names, "_u")

            if not explode_alias:
                explode_alias = new_name(taken_select_names, "col")

                if is_posexplode:
                    pos_alias = new_name(taken_select_names, "pos")

            if not pos_alias:
                pos_alias = new_name(taken_select_names, "pos")

            alias.set("alias", exp.to_identifier(explode_alias))

            series_table_alias = series.args["alias"].this
            exploded_value = exp.If(
                this=positions_match(series_table_alias, pos_alias, unnest_source_alias),
                true=exp.column(explode_alias, table=unnest_source_alias),
            )

            explode.replace(exploded_value)

            if is_posexplode:
                # The position projection goes right after the exploded value's projection.
                expressions = expression.expressions
                insert_at = expressions.index(alias) + 1
                position_projection = exp.If(
                    this=positions_match(series_table_alias, pos_alias, unnest_source_alias),
                    true=exp.column(pos_alias, table=unnest_source_alias),
                ).as_(pos_alias)
                expressions.insert(insert_at, position_projection)
                expression.set("expressions", expressions)

            if not arrays:
                # The position series is attached once, when the first explode is found.
                if expression.args.get("from"):
                    expression.join(series, copy=False, join_type="CROSS")
                else:
                    expression.from_(series, copy=False)

            size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
            arrays.append(size)

            # trino doesn't support left join unnest with on conditions
            # if it did, this would be much simpler
            unnest_source = exp.alias_(
                exp.Unnest(
                    expressions=[explode_arg.copy()],
                    offset=exp.to_identifier(pos_alias),
                ),
                unnest_source_alias,
                table=[explode_alias],
            )
            expression.join(unnest_source, join_type="CROSS", copy=False)

            if index_offset != 1:
                size = size - 1

            expression.where(
                positions_match(series_table_alias, pos_alias, unnest_source_alias).or_(
                    (exp.column(series_alias, table=series_table_alias) > size).and_(
                        exp.column(pos_alias, table=unnest_source_alias).eq(size)
                    )
                ),
                copy=False,
            )

        if arrays:
            # The series ends at the largest of the collected array sizes.
            end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

            if index_offset != 1:
                end = end - (1 - index_offset)
            series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
350
351
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    A two-argument percentile that is not already wrapped in WITHIN GROUP keeps its second
    argument as its only operand, while its first argument becomes the ORDER BY of the new
    WITHIN GROUP node.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)
365
366
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    The WITHIN GROUP node is replaced by an ApproxQuantile whose input is the first ordered
    value found under the node and whose quantile is the percentile's operand.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))
379
380
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    # Fallback names for projections that have neither an alias nor a name.
    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        # For a UNION, the column names come from its left-hand query.
        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        table_alias.set("columns", column_names)

    return expression
398
399
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts to temporal types by the equivalent date literal."""
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    # The comparison is case-insensitive, so 'EPOCH' and 'Epoch' are rewritten as well.
    is_epoch = expression.name.lower() == "epoch"
    if is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
410
411
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        on = join.args.get("on")
        if not on or join.kind not in ("SEMI", "ANTI"):
            continue

        # The join condition becomes the filter of a correlated "SELECT 1" subquery.
        probe = exp.select("1").from_(join.this).where(on)
        exists = exp.Exists(this=probe)
        if join.kind == "ANTI":
            exists = exists.not_(copy=False)

        join.pop()
        expression.where(exists, copy=False)

    return expression
427
428
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_outer_joins = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]
    if len(full_outer_joins) != 1:
        return expression

    # The copy is taken before the limit is cleared, so only the original query loses it.
    right_query = expression.copy()
    expression.set("limit", None)

    position, full_outer_join = full_outer_joins[0]
    full_outer_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
453
454
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")

    for with_node in expression.find_all(exp.With):
        if with_node.parent is expression:
            continue

        nested_with = with_node.pop()

        if not top_level_with:
            # No top-level WITH yet: the first nested one found is promoted as is.
            top_level_with = nested_with
            expression.set("with", top_level_with)
            continue

        if nested_with.recursive:
            top_level_with.set("recursive", True)

        # Nested CTEs are placed ahead of the ones that are already at the top level.
        top_level_with.set("expressions", nested_with.expressions + top_level_with.expressions)

    return expression
483
484
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    # Aliased so that the optimizer helper doesn't shadow this function's own name.
    from sqlglot.optimizer.canonicalize import ensure_bools as canonicalize_ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        # Numbers, unknown or numeric typed nodes, and untyped columns become `<node> <> 0`.
        needs_comparison = (
            node.is_number
            or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(node, exp.Column) and not node.type)
        )
        if needs_comparison:
            node.replace(node.neq(0))

    for visited in expression.walk():
        canonicalize_ensure_bools(visited, _ensure_bool)

    return expression
501
502
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strip every part except the last one from each column found in the expression."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
510
511
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Remove the parent node of every unique column constraint in a CREATE expression."""
    assert isinstance(expression, exp.Create)

    for unique_constraint in expression.find_all(exp.UniqueColumnConstraint):
        wrapper = unique_constraint.parent
        if wrapper:
            wrapper.pop()

    return expression
519
520
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TABLE statements that have a TEMPORARY property to temporary views.

    Args:
        expression: the CREATE expression that will be transformed.
        tmp_storage_provider: called with the expression when a temporary table is created
            without a query; by default the expression is returned as is.

    Returns:
        A CREATE TEMPORARY VIEW expression for a temporary CTAS, the result of
        `tmp_storage_provider` for other temporary tables, or the input expression otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
543
544
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and not isinstance(prop.this, exp.Schema)):
        return expression

    schema = expression.this

    # Partition names are matched against the schema's columns case-insensitively.
    partition_names = {value.name.upper() for value in prop.this.expressions}
    partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
    remaining = [col for col in schema.expressions if col not in partitions]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return expression
566
567
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When PARTITIONED BY holds a schema made up solely of typed column definitions, those
    definitions are appended to the table's schema and the property is reduced to a tuple of
    the column names.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not prop:
        return expression

    partition_schema = prop.this
    if not partition_schema or not isinstance(partition_schema, exp.Schema):
        return expression

    if not all(
        isinstance(column_def, exp.ColumnDef) and column_def.kind
        for column_def in partition_schema.expressions
    ):
        return expression

    # The names are collected before the definitions are handed over to the table's schema.
    partition_names = exp.Tuple(
        expressions=[
            exp.to_identifier(column_def.this) for column_def in partition_schema.expressions
        ]
    )

    schema = expression.this
    for column_def in partition_schema.expressions:
        schema.append("expressions", column_def)

    prop.set("this", partition_names)

    return expression
591
592
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if not isinstance(expression, exp.Struct):
        return expression

    converted = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            # The key (`this`) becomes the alias of the value (`expression`).
            field = exp.alias_(field.expression, field.this)
        converted.append(field)

    expression.set("expressions", converted)
    return expression
605
606
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated unchanged.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the resulting expression has no "_sql"
            method and either has no `Generator.TRANSFORMS` entry, or has one but is a
            non-function expression of the same type as the original one.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # A plain loop also covers the empty case, where indexing transforms[0] would fail.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to projection aliases for the projections' ordinal positions.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Map every aliased projection to its 1-based position in the SELECT list.
    position_by_alias = {
        projection.alias: position
        for position, projection in enumerate(select.expressions, start=1)
        if isinstance(projection, exp.Alias)
    }

    for term in expression.expressions:
        # Only unqualified column references can refer to a projection alias.
        if not isinstance(term, exp.Column) or term.table:
            continue

        position = position_by_alias.get(term.name)
        if position is not None:
            term.replace(exp.Literal.number(position))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as a subquery that is filtered on a ROW_NUMBER() window.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not (on and isinstance(on, exp.Tuple)):
        return expression

    # Detach the DISTINCT node; its ON columns become the window's partition.
    partition_cols = distinct.pop().args["on"].expressions
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    ranking = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)
    order = expression.args.get("order")

    if order:
        # The query's ORDER BY is moved into the window.
        ranking.set("order", order.pop())
    else:
        ranking.set("order", exp.Order(expressions=[col.copy() for col in partition_cols]))

    expression.select(exp.alias_(ranking, row_number_alias), copy=False)

    outer = exp.select(*projections, copy=False)
    outer = outer.from_(expression.subquery("_t", copy=False), copy=False)
    return outer.where(exp.column(row_number_alias).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 87def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 88    """
 89    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 90
 91    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 92    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 93
 94    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 95    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 96    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 97    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
 98    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
 99    corresponding expression to avoid creating invalid column references.
100    """
101    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
102        taken = set(expression.named_selects)
103        for select in expression.selects:
104            if not select.alias_or_name:
105                alias = find_new_name(taken, "_c")
106                select.replace(exp.alias_(select, alias))
107                taken.add(alias)
108
109        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
110        qualify_filters = expression.args["qualify"].pop().this
111        expression_by_alias = {
112            select.alias: select.this
113            for select in expression.selects
114            if isinstance(select, exp.Alias)
115        }
116
117        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
118        for select_candidate in qualify_filters.find_all(select_candidates):
119            if isinstance(select_candidate, exp.Window):
120                if expression_by_alias:
121                    for column in select_candidate.find_all(exp.Column):
122                        expr = expression_by_alias.get(column.name)
123                        if expr:
124                            column.replace(expr)
125
126                alias = find_new_name(expression.named_selects, "_w")
127                expression.select(exp.alias_(select_candidate, alias), copy=False)
128                column = exp.column(alias)
129
130                if isinstance(select_candidate.parent, exp.Qualify):
131                    qualify_filters = column
132                else:
133                    select_candidate.replace(column)
134            elif select_candidate.name not in expression.named_selects:
135                expression.select(select_candidate.copy(), copy=False)
136
137        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
138            qualify_filters, copy=False
139        )
140
141    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the precision from parameterized types in
    expressions, by dropping every `DataTypeParam` argument of each data type found.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [
            param for param in data_type.expressions if not isinstance(param, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Only UNNESTs that sit directly under a FROM or JOIN of the given SELECT's scope are
    considered. Columns whose `table` (or else `db`) part matches one of their aliases have
    that part cleared.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of a SELECT whose target is an UNNEST is removed from the joins, and one
    LATERAL VIEW is appended per (unnested expression, alias column) pair. POSEXPLODE is used
    when the UNNEST has an offset, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: matching joins are removed from `joins` inside the loop, and
        # removing from a list while iterating over it skips the element after each removal.
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                # NOTE(review): an UNNEST without an alias yields no lateral at all here.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the start value of the generated position series. It also controls how the
            array sizes are adjusted when compared against positions (see the inline comments).

    Returns:
        A transform that rewrites the EXPLODE / POSEXPLODE projections of a SELECT into
        cross-joined UNNEST sources; other expressions are returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names that are already in use, so generated names don't clash with them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a free name based on `name` and reserves it in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per exploded array, used for the series' end below.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(index_offset, <end>)) with a single column `series_alias`.
            # The "end" argument is only set after the loop, once all arrays are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into a single Alias node (`alias`), remembering
                    # any user-provided names for the exploded value and its position.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # exp.alias_ copies the projection here, so the explode node has to be
                        # looked up again inside the new Alias.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Builds IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg),
                        # where the arg[0] bracket has its "safe" and "offset" args set.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate names for whatever the user didn't name explicitly.
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call is replaced by
                    # IF(<series>.<series_alias> = <unnest>.<pos_alias>, <unnest>.<explode_alias>)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position: add it as an extra projection,
                        # placed right after the exploded value.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode found: attach the series source exactly once, as a
                        # cross join when the query has a FROM clause, or as the FROM otherwise.
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # The size compared against positions below is ARRAY_SIZE - 1, unless
                    # index_offset is 1. Note that `arrays` keeps the unadjusted size.
                    if index_offset != 1:
                        size = size - 1

                    # Keep the rows where the series position matches the unnest position, or
                    # where the series is past `size` and the unnest position equals `size`.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at GREATEST(<array sizes>), minus (1 - index_offset) unless
                # index_offset is 1.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    Applies to two-argument percentile expressions that are not already wrapped in a
    WITHIN GROUP: the first argument becomes the ORDER BY of the new WITHIN GROUP, and the
    second argument becomes the percentile's only argument.
    """
    needs_within_group = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not needs_within_group:
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    A `<percentile>(q) WITHIN GROUP (ORDER BY x)` node is replaced by an `ApproxQuantile`
    whose input is `x` (the first ordered term found) and whose quantile is `q`.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx_quantile)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs of a recursive WITH that don't declare columns yet are touched. For a UNION, the
    names are taken from its left-hand query. Projections without a name get a generated
    `_c_`-prefixed one.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this.this if isinstance(cte.this, exp.Union) else cte.this
        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Applies to CAST / TRY_CAST nodes whose name is "epoch" (case-insensitive) and whose target
    is a temporal type: the cast operand is replaced by the string '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    is_epoch = expression.name.lower() == "epoch"
    if is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI / ANTI join that has an ON condition is removed from the SELECT and replaced by a
    `WHERE [NOT] EXISTS (SELECT 1 FROM <join target> WHERE <on condition>)` filter. Joins
    without an ON condition are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: each converted join is popped from the SELECT
        # below, and the list being iterated must not change underneath the loop.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The original query becomes the LEFT side (its limit is cleared after the copy is taken),
    and a copy of it, stripped of its CTEs, becomes the RIGHT side.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_outer_joins = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]
    if len(full_outer_joins) != 1:
        return expression

    right_query = expression.copy()
    expression.set("limit", None)

    position, left_join = full_outer_joins[0]
    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the top-level statement.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    When the statement has no WITH of its own, the first nested one found is adopted as is.
    Otherwise, nested CTEs are placed in front of the ones already collected, and the
    top-level WITH becomes recursive if any nested one is.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")

    for with_node in expression.find_all(exp.With):
        if with_node.parent is expression:
            continue

        detached = with_node.pop()
        if top_level_with:
            if detached.recursive:
                top_level_with.set("recursive", True)

            top_level_with.set("expressions", detached.expressions + top_level_with.expressions)
        else:
            top_level_with = detached
            expression.set("with", top_level_with)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's `canonicalize.ensure_bools`, together
    with a callback that replaces a node by `<node> <> 0` when it is a number, is typed as
    UNKNOWN or numeric, or is a column without type information.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _to_explicit_bool(node: exp.Expression) -> None:
        looks_numeric = node.is_number or node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        )
        if looks_numeric or (isinstance(node, exp.Column) and not node.type):
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _to_explicit_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strip every qualifier from the columns of `expression`, keeping only their last part."""
    for column in expression.find_all(exp.Column):
        # Everything but the last part is a qualifier (table, db, catalog).
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove the UNIQUE column constraints found in a CREATE statement.

    For each `UniqueColumnConstraint`, the node that contains it (its parent) is popped.
    """
    assert isinstance(expression, exp.Create)

    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        owner = constraint.parent
        if owner:
            owner.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Turn a CREATE TEMPORARY TABLE ... AS <query> into a CREATE TEMPORARY VIEW.

    Args:
        expression: the CREATE statement to transform.
        tmp_storage_provider: applied to temporary CREATE TABLE statements that have no query;
            its result is returned instead. By default the statement is returned as is.

    Returns:
        The transformed expression. Statements that aren't temporary tables are unchanged.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    props = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in props)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Only CREATE TABLE / VIEW statements that have a schema are affected. Column names are
    matched case-insensitively.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema):
        return expression
    if expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    partition_names = {value.name.upper() for value in prop.this.expressions}
    partition_columns = [col for col in schema.expressions if col.name.upper() in partition_names]
    remaining_columns = [col for col in schema.expressions if col not in partition_columns]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When PARTITIONED BY holds a schema made only of typed column definitions, those
    definitions are appended to the statement's own schema, and PARTITIONED BY is reduced to
    a tuple of the column names.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    partition_schema = prop.this if prop else None
    if not (
        partition_schema
        and isinstance(partition_schema, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in partition_schema.expressions)
    ):
        return expression

    partition_defs = partition_schema.expressions
    column_names = exp.Tuple(expressions=[exp.to_identifier(col.this) for col in partition_defs])

    schema = expression.this
    for col in partition_defs:
        schema.append("expressions", col)
    prop.set("this", column_names)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each `PropertyEQ` argument of a struct becomes its value aliased by its key; other
    arguments are kept as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    converted = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            arg = exp.alias_(arg.expression, arg.this)
        converted.append(arg)

    expression.set("expressions", converted)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated as is.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the resulting expression can't be
            converted to SQL, i.e. when the generator has neither a "_sql" method nor a usable
            `TRANSFORMS` entry for it.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the input type, to detect transforms that return the same kind of node.
        expression_type = type(expression)

        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.