Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.errors import UnsupportedError
  7from sqlglot.helper import find_new_name, name_sequence
  8
  9
 10if t.TYPE_CHECKING:
 11    from sqlglot.generator import Generator
 12
 13
 14def preprocess(
 15    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 16) -> t.Callable[[Generator, exp.Expression], str]:
 17    """
 18    Creates a new transform by chaining a sequence of transformations and converts the resulting
 19    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
 20    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
 21
 22    Args:
 23        transforms: sequence of transform functions. These will be called in order.
 24
 25    Returns:
 26        Function that can be used as a generator transform.
 27    """
 28
 29    def _to_sql(self, expression: exp.Expression) -> str:
 30        expression_type = type(expression)
 31
 32        try:
 33            expression = transforms[0](expression)
 34            for transform in transforms[1:]:
 35                expression = transform(expression)
 36        except UnsupportedError as unsupported_error:
 37            self.unsupported(str(unsupported_error))
 38
 39        _sql_handler = getattr(self, expression.key + "_sql", None)
 40        if _sql_handler:
 41            return _sql_handler(expression)
 42
 43        transforms_handler = self.TRANSFORMS.get(type(expression))
 44        if transforms_handler:
 45            if expression_type is type(expression):
 46                if isinstance(expression, exp.Func):
 47                    return self.function_fallback_sql(expression)
 48
 49                # Ensures we don't enter an infinite loop. This can happen when the original expression
 50                # has the same type as the final expression and there's no _sql method available for it,
 51                # because then it'd re-enter _to_sql.
 52                raise ValueError(
 53                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
 54                )
 55
 56            return transforms_handler(self, expression)
 57
 58        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 59
 60    return _to_sql
 61
 62
 63def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 64    if isinstance(expression, exp.Select):
 65        count = 0
 66        recursive_ctes = []
 67
 68        for unnest in expression.find_all(exp.Unnest):
 69            if (
 70                not isinstance(unnest.parent, (exp.From, exp.Join))
 71                or len(unnest.expressions) != 1
 72                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 73            ):
 74                continue
 75
 76            generate_date_array = unnest.expressions[0]
 77            start = generate_date_array.args.get("start")
 78            end = generate_date_array.args.get("end")
 79            step = generate_date_array.args.get("step")
 80
 81            if not start or not end or not isinstance(step, exp.Interval):
 82                continue
 83
 84            alias = unnest.args.get("alias")
 85            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 86
 87            start = exp.cast(start, "date")
 88            date_add = exp.func(
 89                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 90            )
 91            cast_date_add = exp.cast(date_add, "date")
 92
 93            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 94
 95            base_query = exp.select(start.as_(column_name))
 96            recursive_query = (
 97                exp.select(cast_date_add)
 98                .from_(cte_name)
 99                .where(cast_date_add <= exp.cast(end, "date"))
100            )
101            cte_query = base_query.union(recursive_query, distinct=False)
102
103            generate_dates_query = exp.select(column_name).from_(cte_name)
104            unnest.replace(generate_dates_query.subquery(cte_name))
105
106            recursive_ctes.append(
107                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
108            )
109            count += 1
110
111        if recursive_ctes:
112            with_expression = expression.args.get("with") or exp.With()
113            with_expression.set("recursive", True)
114            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
115            expression.set("with", with_expression)
116
117    return expression
118
119
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a table reference to GENERATE_SERIES (or SEQUENCE) in an UNNEST.

    When the table carries an alias, the UNNEST is aliased as `_u` and the original alias
    becomes its single column name. Any other expression is returned unchanged.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    if not expression.alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[expression.alias], copy=False)
131
132
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Each unqualified GROUP BY column whose name matches an aliased projection of the parent
    SELECT is swapped for that projection's 1-based position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)):
        return expression

    # Alias name -> 1-based ordinal of the projection that defines it.
    position_by_alias: t.Dict[str, int] = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for term in expression.expressions:
        if not isinstance(term, exp.Column) or term.table:
            continue

        if term.name in position_by_alias:
            term.replace(exp.Literal.number(position_by_alias.get(term.name)))

    return expression
164
165
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The original query gains a ROW_NUMBER() window projection partitioned by the DISTINCT ON
    expressions and is wrapped in a subquery aliased `_t`; the new outer query filters on that
    projection being equal to 1. Only DISTINCT ON clauses whose `on` argument is a tuple are
    rewritten.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Detach the DISTINCT node from the query and keep the expressions listed in its ON tuple.
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        # Read before the window projection is added to the query below.
        outer_selects = expression.selects
        # Pick a projection name that doesn't clash with the query's existing named selects.
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # Move the query's ORDER BY into the window.
            window.set("order", order.pop())
        else:
            # No ORDER BY: order each partition by copies of the DISTINCT ON expressions.
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
205
206
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        For a SELECT with a QUALIFY clause, a new outer SELECT over the original query (as a
        subquery aliased `_t`) whose WHERE clause holds the former QUALIFY condition; any other
        expression is returned unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every projection without an alias or name a fresh `_c`-based alias, so the
        # outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Build the outer projection for an inner one: a column that carries over the
            # identifier's `quoted` flag when an identifier is available, else the plain name.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Computed before extra projections are added to the inner query below.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach the QUALIFY node and keep its condition.
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # For star queries only window functions are considered; plain columns are not.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline aliased projections referenced inside the window by name.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the inner query under a fresh `_w`-based alias.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window's parent is the QUALIFY node itself, so the whole condition
                    # becomes the alias column.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column referenced by the condition but not projected: add it to the inner query.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
269
270
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip type parameters (e.g. precision) from every data type found in the expression.

    Some dialects only accept parameterized types in DDL, so elsewhere the `DataTypeParam`
    children of each `DataType` node are dropped while its other children are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        non_param_children = [
            child for child in data_type.expressions if not isinstance(child, exp.DataTypeParam)
        ]
        data_type.set("expressions", non_param_children)

    return expression
282
283
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    For a SELECT, the aliases of UNNESTs placed directly in FROM or JOIN are collected; every
    column whose table (or, failing that, db) part equals one of them has that part cleared.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if not aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in aliases:
            column.set("table", None)
        elif column.db in aliases:
            column.set("db", None)

    return expression
302
303
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause is replaced by a table wrapping the matching UDTF, and each
    joined UNNEST (optionally wrapped in LATERAL) is removed from the joins and re-added as
    one or more LATERAL VIEW entries.

    Args:
        expression: the expression that will be transformed.
        unnest_using_arrays_zip: whether an UNNEST with several input arrays may be rewritten
            as INLINE(ARRAYS_ZIP(...)). When False such an UNNEST raises `UnsupportedError`.

    Returns:
        The transformed expression (SELECTs are modified in place).

    Raises:
        UnsupportedError: an UNNEST has multiple inputs and `unnest_using_arrays_zip` is False.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapse several UNNEST inputs into a single ARRAYS_ZIP call (also stored on `u`).
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # An offset takes precedence; otherwise multi-input uses INLINE and single uses EXPLODE.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Iterate over a snapshot: matching joins are removed from the live list inside the
        # loop, and removing from a list while iterating it skips the following element.
        for join in list(expression.args.get("joins") or []):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                expression.args["joins"].remove(join)

                # Without an alias there are no columns to pair with, so no lateral is added.
                alias_cols = alias.columns if alias else []
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols if unnest_using_arrays_zip else [column],  # type: ignore
                            ),
                        ),
                    )

    return expression
384
385
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: value used as the start of the generated position series; it also
            determines how the array-size bounds below are adjusted.

    Returns:
        A transform that rewrites SELECT expressions in place (other expressions are returned
        unchanged).
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated projection/source names don't clash.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve and return a name based on `name` that isn't in `names` yet.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per exploded array; also tells whether any was found.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Position series source; its "end" is only filled in once all arrays are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Make sure the projection is an Alias node and pick up user-given aliases.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # Look the explode up again inside the replacement projection.
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Replace an empty (or NULL, via COALESCE) array by a one-element array
                        # built from a "safe" bracket access on the argument.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The exploded value is only projected on rows where the series position
                    # matches this unnest's own position (the IF has no false branch).
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Add the position projection right after the value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode found: attach the position series source exactly once.
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, or where the series has moved past
                    # this array's bound and the unnest position equals that bound.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must run up to the size of the largest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
535
536
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    Applies to two-argument percentile functions not already inside a WITHIN GROUP: the
    second argument becomes the function's only argument and the first becomes the ORDER BY
    expression of the new WITHIN GROUP wrapper, which is returned.
    """
    needs_within_group = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not needs_within_group:
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)
550
551
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    A WITHIN GROUP that wraps a percentile function and an ORDER BY is replaced by an
    `ApproxQuantile` whose input is the first ordered expression and whose quantile is the
    percentile function's argument.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))
564
565
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs of a recursive WITH that have no column list are touched. For set operations
    the names come from the left-hand query; projections without a name get a generated
    `_c_`-prefixed one.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        cte_alias.set("columns", column_names)

    return expression
583
584
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Applies to CAST / TRY_CAST nodes whose operand is named "epoch" (case-insensitive) and
    whose target type is temporal; the operand becomes the string '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression
    if expression.name.lower() != "epoch":
        return expression
    if expression.to.this not in exp.DataType.TEMPORAL_TYPES:
        return expression

    expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
    return expression
595
596
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.

    Every SEMI/ANTI join of a SELECT that has an ON condition is removed and replaced by an
    `EXISTS (SELECT 1 FROM <joined source> WHERE <on>)` predicate (negated for ANTI) added to
    the query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (SELECTs are modified in place).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` detaches the join from the query, and mutating
        # the joins list while iterating it would skip the join right after each removed one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
612
613
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The original query becomes the LEFT-join side (its limit is cleared), and a copy of it
    becomes the RIGHT-join side (its CTEs are dropped).
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL":
            full_joins.append((position, join))

    if len(full_joins) != 1:
        return expression

    # Copy first, so that the right-hand query is a snapshot of the untouched original.
    right_query = expression.copy()
    expression.set("limit", None)

    position, left_join = full_joins[0]
    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
638
639
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    CTEs found inside another CTE are inserted right before that enclosing CTE; all others
    are appended at the end of the top-level WITH.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    root_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        if nested_with.parent is expression:
            continue

        if not root_with:
            # No top-level WITH yet: promote the first nested one as-is.
            root_with = nested_with.pop()
            expression.set("with", root_with)
            continue

        if nested_with.recursive:
            root_with.set("recursive", True)

        # Must be looked up before the nested WITH is detached from the tree.
        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if enclosing_cte:
            position = root_with.expressions.index(enclosing_cte)
            root_with.expressions[position:position] = nested_with.expressions
            root_with.set("expressions", root_with.expressions)
        else:
            root_with.set("expressions", root_with.expressions + nested_with.expressions)

    return expression
677
678
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's `ensure_bools` helper together with a
    callback that rewrites a node `x` as `x <> 0` when it is a number, has an unknown or
    numeric type (subquery predicates excluded), or is an untyped column.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _compare_with_zero(node: exp.Expression) -> None:
        if node.is_number:
            needs_comparison = True
        elif not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            needs_comparison = True
        else:
            needs_comparison = isinstance(node, exp.Column) and not node.type

        if needs_comparison:
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _compare_with_zero)

    return expression
698
699
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Detach every part of each column except the last one, leaving columns unqualified."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
707
708
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove unique column constraints from a CREATE statement.

    For each `UniqueColumnConstraint` found, its parent node (when it has one) is detached
    from the tree. The expression must be an `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    for unique_constraint in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique_constraint.parent
        if owner:
            owner.pop()

    return expression
716
717
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Turn `CREATE TEMPORARY TABLE ... AS <query>` into `CREATE TEMPORARY VIEW ... AS <query>`.

    Args:
        expression: the CREATE statement to transform (must be an `exp.Create`).
        tmp_storage_provider: applied to temporary CREATE TABLE statements that have no
            query attached; its result is returned. Defaults to the identity function.

    Returns:
        The new CREATE TEMPORARY VIEW, the provider's result, or the unchanged expression
        when the statement is not a temporary table.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_nodes)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if not expression.expression:
        return tmp_storage_provider(expression)

    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
740
741
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Only CREATE TABLE / CREATE VIEW statements that have a schema are affected, and column
    names are matched case-insensitively.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and not isinstance(prop.this, exp.Schema)):
        return expression

    partition_names = {value.name.upper() for value in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression
763
764
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed PARTITIONED BY column definitions into the table schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    The rewrite only happens when the PARTITIONED BY value is a schema whose entries
    are all column definitions with a type. Those definitions are appended to the
    CREATE statement's schema, and the property is left with a tuple of their names.

    Args:
        expression: the CREATE expression to transform; it is mutated in place.

    Returns:
        The same CREATE expression.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or not isinstance(prop.this, exp.Schema):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(col, exp.ColumnDef) and col.kind for col in partition_defs):
        return expression

    # Build the name-only tuple first, then hand the definitions over to the schema
    name_tuple = exp.Tuple(expressions=[exp.to_identifier(col.this) for col in partition_defs])

    schema = expression.this
    for col in partition_defs:
        schema.append("expressions", col)

    prop.set("this", name_tuple)
    return expression
788
789
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each `exp.PropertyEQ` argument of an `exp.Struct` is replaced by an alias whose
    value is the property's right-hand side and whose name is its left-hand side.
    Other arguments, and non-struct expressions, are left as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    def _as_alias(arg: exp.Expression) -> exp.Expression:
        if isinstance(arg, exp.PropertyEQ):
            return exp.alias_(arg.expression, arg.this)
        return arg

    expression.set("expressions", [_as_alias(arg) for arg in expression.expressions])
    return expression
802
803
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Only scopes that have both a WHERE clause and joins are rewritten. Each predicate
    containing a marked column is removed from its parent and becomes (part of) the ON
    condition of a LEFT JOIN against the marked column's table.

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.

    Raises:
        AssertionError: if a marked column is not part of a binary predicate, if both
            sides of a predicate carry marks, if a marked column is unqualified, if the
            marked columns of one predicate belong to more than one table, or if no
            table is left to serve as the new FROM source.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Nothing to rewrite unless the scope has both a WHERE clause and joins
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Remember the parent before detaching: it may need collapsing further down
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the tables the marked columns belong to
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last column visited by the loop above; the assertion guarantees
            # the loop ran and that every marked column shares this table. If the table is
            # not one of the existing joins, the FROM source is used as the join target.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # The original FROM source became a LEFT JOIN target, so one of the old joins
        # that was not rewritten has to take its place in the FROM clause
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # The query's joins are replaced wholesale by the newly built LEFT JOINs
        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause if every predicate was moved into a join condition
        if not where.this:
            where.pop()

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed and leaves the expression untouched.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the transformed expression has no
            "_sql" method and either has no `Generator.TRANSFORMS` entry, or has one but is
            a non-function expression of the same type as the original (which would recurse
            forever).
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            # A plain loop (rather than indexing transforms[0]) keeps an empty chain valid
            for transform in transforms:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            # Report the problem and carry on with the last successfully transformed expression
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 64def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 65    if isinstance(expression, exp.Select):
 66        count = 0
 67        recursive_ctes = []
 68
 69        for unnest in expression.find_all(exp.Unnest):
 70            if (
 71                not isinstance(unnest.parent, (exp.From, exp.Join))
 72                or len(unnest.expressions) != 1
 73                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 74            ):
 75                continue
 76
 77            generate_date_array = unnest.expressions[0]
 78            start = generate_date_array.args.get("start")
 79            end = generate_date_array.args.get("end")
 80            step = generate_date_array.args.get("step")
 81
 82            if not start or not end or not isinstance(step, exp.Interval):
 83                continue
 84
 85            alias = unnest.args.get("alias")
 86            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 87
 88            start = exp.cast(start, "date")
 89            date_add = exp.func(
 90                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 91            )
 92            cast_date_add = exp.cast(date_add, "date")
 93
 94            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 95
 96            base_query = exp.select(start.as_(column_name))
 97            recursive_query = (
 98                exp.select(cast_date_add)
 99                .from_(cte_name)
100                .where(cast_date_add <= exp.cast(end, "date"))
101            )
102            cte_query = base_query.union(recursive_query, distinct=False)
103
104            generate_dates_query = exp.select(column_name).from_(cte_name)
105            unnest.replace(generate_dates_query.subquery(cte_name))
106
107            recursive_ctes.append(
108                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
109            )
110            count += 1
111
112        if recursive_ctes:
113            with_expression = expression.args.get("with") or exp.With()
114            with_expression.set("recursive", True)
115            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
116            expression.set("with", with_expression)
117
118    return expression
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A table whose source is an `exp.GenerateSeries` is wrapped in an UNNEST. If the
    table has an alias, the UNNEST is aliased as "_u" with the original alias as its
    single column name. Any other expression is returned unchanged.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    if not expression.alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[expression.alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Unqualified GROUP BY columns that name an aliased projection of the enclosing
    SELECT are replaced by that projection's 1-based position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group) or not isinstance(expression.parent, exp.Select):
        return expression

    # Map each projection alias to its ordinal position in the SELECT list
    position_by_alias = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for term in expression.expressions:
        if not isinstance(term, exp.Column) or term.table:
            continue

        if term.name in position_by_alias:
            term.replace(exp.Literal.number(position_by_alias.get(term.name)))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The original query gains a ROW_NUMBER() projection partitioned by the DISTINCT ON
    columns. It is ordered by the query's ORDER BY clause (which is moved into the
    window) or, when there is none, by the DISTINCT ON columns themselves. The query is
    then wrapped in a subquery named "_t" and the outer query keeps only the rows whose
    row number equals 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if (
        not distinct
        or not distinct.args.get("on")
        or not isinstance(distinct.args["on"], exp.Tuple)
    ):
        return expression

    partition_cols = distinct.pop().args["on"].expressions

    # Capture the projections before the row-number column is added to the inner query
    outer_selects = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    numbering = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    order = expression.args.get("order")
    if order:
        numbering.set("order", order.pop())
    else:
        numbering.set("order", exp.Order(expressions=[col.copy() for col in partition_cols]))

    expression.select(exp.alias_(numbering, row_number_alias), copy=False)

    outer_query = exp.select(*outer_selects, copy=False)
    outer_query = outer_query.from_(expression.subquery("_t", copy=False), copy=False)
    return outer_query.where(exp.column(row_number_alias).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression to transform. Only SELECT statements with a QUALIFY
            clause are rewritten; they are mutated in place and wrapped in a subquery.

    Returns:
        A new outer SELECT over the original query (aliased "_t") whose WHERE clause holds
        the former QUALIFY condition, or the original expression if nothing applies.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c"-based alias so the outer query can refer to it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Build a column that carries over the identifier's "quoted" flag when the
            # projection's name comes from an identifier; otherwise use the bare name
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are computed before any helper projection is added below
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection only windows are searched for, not plain columns
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                # Inline aliased projections referenced inside the window
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the inner query under a fresh "_w"-based alias
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # The window was the whole QUALIFY condition, so the filter becomes the alias itself
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A filtered column that is not projected must be added to the inner query
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every `exp.DataType` found under `expression` has its `exp.DataTypeParam` arguments
    dropped; all other type arguments are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [
            arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept_args)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    The aliases of the UNNESTs that sit directly in a FROM or JOIN of this SELECT's
    scope are collected. Any column whose table part matches one of them loses that
    part; otherwise, if its db part matches, the db part is cleared.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.Expression:
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause is replaced by a table that wraps the matching
    table-generating function. Every UNNEST join (optionally wrapped in LATERAL) is
    removed from the joins and re-added as LATERAL VIEW entries. The function used is
    POSEXPLODE when the UNNEST has an offset, INLINE when it has several input arrays,
    and EXPLODE otherwise.

    Args:
        expression: the expression to transform; SELECT statements are mutated in place.
        unnest_using_arrays_zip: whether an UNNEST with several input arrays may be
            rewritten using INLINE(ARRAYS_ZIP(...)).

    Returns:
        The same expression.

    Raises:
        UnsupportedError: if an UNNEST has several input arrays and
            `unnest_using_arrays_zip` is False.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: joins are removed from the live list inside the loop,
        # and removing from a list while iterating it would skip the following join
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For LATERAL UNNEST the alias lives on the LATERAL node, not on the UNNEST
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                # One LATERAL VIEW is emitted per (expression, alias column) pair
                alias_cols = alias.columns if alias else []
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols if unnest_using_arrays_zip else [column],  # type: ignore
                            ),
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the first value of the generated position series. It is also used
            to adjust the array sizes that the positions are compared against.

    Returns:
        A transform that rewrites the EXPLODE / POSEXPLODE projections of a SELECT into
        cross joins against UNNEST sources, driven by a generated series of positions.
        Other expressions are returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated aliases don't clash with them
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a name based on `name` that is not in `names`, and reserve it
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Size expressions of all exploded arrays; these bound the series at the end
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # The series' end is not known yet: it is set once all explodes are processed
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Unaliased projection: wrap it in a placeholder alias and look the
                        # explode up again inside the new node
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # NOTE(review): explode_arg[0] presumably builds an index/bracket
                        # expression on the array -- confirm against Expression.__getitem__
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        # When the (coalesced) array is empty, explode a one-element array instead
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The exploded value is only emitted on the row where the series
                    # position matches the unnest position
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position: add it right after the value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is attached once, when the first explode is found
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, plus rows where the series has
                    # run past this array's size and the unnest sits at that size
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the largest of the exploded arrays
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile function in a WITHIN GROUP clause.

    The rewrite only applies to `exp.PERCENTILES` nodes that carry a second argument
    (`expression.expression`) and are not already the child of a `exp.WithinGroup`.
    The first argument becomes the ORDER BY column of the WITHIN GROUP clause, and the
    second argument becomes the sole argument (`this`) of the percentile function.

    Args:
        expression: The node to transform.

    Returns:
        The new `exp.WithinGroup` node, or the input node when the rewrite doesn't apply.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    # Detach both arguments before re-attaching them in their new positions
    ordered_column = expression.this.pop()
    expression.set("this", expression.expression.pop())

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_column)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile>(q) WITHIN GROUP (ORDER BY x)` with an `exp.ApproxQuantile` node.

    The rewrite only applies to `exp.WithinGroup` nodes whose `this` is one of
    `exp.PERCENTILES` and whose `expression` is an `exp.Order`.

    Args:
        expression: The node to transform.

    Returns:
        The `exp.ApproxQuantile` node that replaced the WITHIN GROUP clause in the tree,
        or the input node when the rewrite doesn't apply.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not (
        isinstance(percentile, exp.PERCENTILES) and isinstance(expression.expression, exp.Order)
    ):
        return expression

    quantile = percentile.this
    # The value being aggregated is the first ORDER BY term found under the WITHIN GROUP
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx_quantile)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit column list.

    CTEs whose alias already declares columns are left alone. For the others, the column
    names are taken from the output names of the CTE query's projections (for a set
    operation, the projections of its left-most operand, i.e. `query.this`). Projections
    without an output name get a generated `_c_<n>` name.

    Args:
        expression: The node to transform.

    Returns:
        The same node, modified in place when it is a recursive `exp.With`.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    # A single sequence is shared by all CTEs so generated names never repeat
    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        alias = cte.args["alias"]
        if alias.columns:
            continue

        source = cte.this
        if isinstance(source, exp.SetOperation):
            source = source.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or fresh_name())
            for projection in source.selects
        ]
        alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand 'epoch' of a cast to a temporal type with '1970-01-01 00:00:00'.

    Args:
        expression: The node to transform.

    Returns:
        The same node; when it is a `Cast`/`TryCast` of 'epoch' (case-insensitive) to one of
        `exp.DataType.TEMPORAL_TYPES`, its operand is replaced in place by the string literal.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    if expression.name.lower() != "epoch":
        return expression

    if expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Every SEMI/ANTI join of a SELECT that has an ON condition is removed and replaced
    by an additional WHERE predicate of the form `EXISTS(SELECT 1 FROM <joined> WHERE <on>)`,
    negated for ANTI joins. Joins without an ON condition are left untouched.

    Args:
        expression: The node to transform.

    Returns:
        The same node, modified in place when it is an `exp.Select`.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: `join.pop()` detaches the join from the
        # SELECT, and removing entries from the list being iterated would make the loop
        # skip the join that follows each rewritten one (e.g. two consecutive SEMI joins).
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a SELECT containing exactly one FULL OUTER join as a union of two queries,
    one using a LEFT join and the other a RIGHT join in its place.

    Queries with zero or several FULL OUTER joins are returned unchanged.

    Args:
        expression: The node to transform.

    Returns:
        The union of the LEFT-join query (the input, modified in place) and the RIGHT-join
        query (a copy of the input), or the input node when the rewrite doesn't apply.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]

    if len(full_joins) != 1:
        return expression

    position, left_side_join = full_joins[0]

    # The copy must be taken before the original is touched, so that it still carries
    # everything the input query had (including its LIMIT)
    right_query = expression.copy()
    expression.set("limit", None)

    left_side_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the WITH clause of the top-level expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: The top-level node whose nested CTEs should be hoisted.

    Returns:
        The same node, modified in place.
    """
    outer_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        if nested_with.parent is expression:
            # This is the top-level WITH itself, nothing to move
            continue

        if not outer_with:
            # No top-level WITH yet: the first nested one found takes that role
            outer_with = nested_with.pop()
            expression.set("with", outer_with)
            continue

        if nested_with.recursive:
            outer_with.set("recursive", True)

        # The ancestor must be looked up before the WITH is detached from the tree
        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if enclosing_cte:
            # CTEs nested inside another CTE are inserted right before that CTE
            position = outer_with.expressions.index(enclosing_cte)
            outer_with.expressions[position:position] = nested_with.expressions
            outer_with.set("expressions", outer_with.expressions)
        else:
            outer_with.set("expressions", outer_with.expressions + nested_with.expressions)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric values used as conditions into explicit `<value> <> 0` comparisons.

    Every node of the tree is handed to `sqlglot.optimizer.canonicalize.ensure_bools`
    together with a callback that performs the replacement on the nodes it is given.

    Args:
        expression: The root of the tree to transform.

    Returns:
        The same node, modified in place.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        # A node is compared against 0 when it is a number, when it is typed as
        # numeric/unknown (subquery predicates excluded), or when it is an untyped column
        looks_numeric = (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        )

        if looks_numeric:
            node.replace(node.neq(0))

    for descendant in expression.walk():
        _canonicalize_ensure_bools(descendant, _ensure_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip the qualifiers from every column in the tree, keeping only its last part.

    Args:
        expression: The root of the tree to transform.

    Returns:
        The same node, modified in place.
    """
    for column in expression.find_all(exp.Column):
        # Everything but the last part is a qualifier (table, db, catalog)
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove the parent node of every `exp.UniqueColumnConstraint` found in a CREATE statement.

    Args:
        expression: The `exp.Create` node to transform.

    Returns:
        The same node, modified in place.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique.parent
        if owner:
            # The node that owns the constraint is removed, not just the constraint itself
            owner.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map `CREATE TEMPORARY TABLE ... AS <query>` to `CREATE TEMPORARY VIEW ... AS <query>`.

    Args:
        expression: The `exp.Create` node to transform.
        tmp_storage_provider: Callback applied to temporary CREATE TABLE statements that
            have no source query. Defaults to returning the statement unchanged.

    Returns:
        A new `exp.Create` of kind "TEMPORARY VIEW" for temporary CTAS statements, the result
        of `tmp_storage_provider` for other temporary tables, and the input node otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    declared_properties = properties.expressions if properties else []
    temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in declared_properties)

    if expression.kind != "TABLE" or not temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move the partition columns of a CREATE TABLE/VIEW out of the schema and into the
    PARTITIONED BY property.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Args:
        expression: The `exp.Create` node to transform.

    Returns:
        The same node, modified in place when the rewrite applies.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this) or isinstance(prop.this, exp.Schema):
        return expression

    # Column names are matched case-insensitively
    partition_names = {value.name.upper() for value in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed column definitions from the PARTITIONED BY property into the table schema,
    leaving only a tuple of their names in the property.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    The rewrite only applies when the PARTITIONED BY value is an `exp.Schema` whose entries
    are all `exp.ColumnDef` nodes with a `kind` set.

    Args:
        expression: The `exp.Create` node to transform.

    Returns:
        The same node, modified in place when the rewrite applies.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    applies = (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    )
    if not applies:
        return expression

    partition_defs = prop.this.expressions
    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(column_def.this) for column_def in partition_defs]
    )

    schema = expression.this
    for column_def in partition_defs:
        schema.append("expressions", column_def)

    prop.set("this", partition_names)
    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite the key-value (`exp.PropertyEQ`) arguments of a struct as aliased values,
    e.g. STRUCT(1 AS y). Other arguments are kept as they are.

    Args:
        expression: The node to transform.

    Returns:
        The same node, modified in place when it is an `exp.Struct`.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    def _as_alias(arg: exp.Expression) -> exp.Expression:
        # The key of the pair becomes the alias of its value
        if isinstance(arg, exp.PropertyEQ):
            return exp.alias_(arg.expression, arg.this)
        return arg

    expression.set("expressions", [_as_alias(arg) for arg in expression.expressions])
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite predicates that use the `(+)` join marker as explicit LEFT joins. This rule
    assumes that all marked columns are qualified. If this does not hold for a query,
    consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    def _marked_columns(side: exp.Expression) -> t.List[exp.Column]:
        # Columns under one side of a predicate that still carry the (+) marker
        return [c for c in side.find_all(exp.Column) if c.args.get("join_mark")]

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Only queries that have both a WHERE clause and joins are rewritten
        if not (where and joins):
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        old_joins = {join.alias_or_name: join for join in joins}
        new_joins: t.Dict[str, exp.Join] = {}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # The parent must be captured before the predicate is detached from the tree
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = _marked_columns(join_predicate.left)
            right_columns = _marked_columns(join_predicate.right)

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for marked_col in left_columns or right_columns:
                table = marked_col.table
                assert table, f"Column {marked_col} needs to be qualified with a table"

                marked_col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # The marked table is looked up among the existing joins, with the FROM clause
            # as the fallback
            join_this = old_joins.get(marked_col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            join_name = new_join.alias_or_name
            existing_join = new_joins.get(join_name)
            if existing_join:
                combined_on = exp.and_(existing_join.args.get("on"), new_join.args["on"])
                existing_join.set("on", combined_on)
            else:
                new_joins[join_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    remaining_operand = predicate_parent.right
                else:
                    remaining_operand = predicate_parent.left
                predicate_parent.replace(remaining_operand)

        if query_from.alias_or_name in new_joins:
            # The FROM table became the target of a LEFT join, so another table has to
            # take its place in the FROM clause
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause if all of its predicates were moved into joins
        if not where.this:
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example, `SELECT * FROM a, b WHERE a.id = b.id(+)` is converted to `SELECT * FROM a LEFT JOIN b ON a.id = b.id`.

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.