Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.errors import UnsupportedError
  7from sqlglot.helper import find_new_name, name_sequence
  8
  9
 10if t.TYPE_CHECKING:
 11    from sqlglot._typing import E
 12    from sqlglot.generator import Generator
 13
 14
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the input's type: the chained transforms may replace the node with a
        # node of a different type, and we use this below to detect when they did not.
        expression_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            # Record the problem on the generator instead of aborting; SQL generation
            # continues with the (possibly partially transformed) expression.
            self.unsupported(str(unsupported_error))

        # Prefer a dedicated Generator method for the resulting node, e.g. "select_sql".
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                # The TRANSFORMS entry for an unchanged type may be this very function,
                # so calling it again would recurse forever; fall back for functions.
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
 62
 63
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites UNNEST(GENERATE_DATE_ARRAY(...)) table sources as recursive CTEs that
    enumerate the dates, for dialects lacking a native GENERATE_DATE_ARRAY.
    """
    if isinstance(expression, exp.Select):
        count = 0  # suffix counter to keep generated CTE names unique within this SELECT
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            # Only rewrite UNNESTs that act as table sources (FROM/JOIN) and wrap
            # exactly one GENERATE_DATE_ARRAY call.
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            # An INTERVAL step is required to build the DATE_ADD recursion below.
            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            # Anchor member selects the start date; the recursive member adds the step
            # to the previous row's date, bounded (inclusive) by the end date.
            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            # Swap the UNNEST for a subquery that selects from the recursive CTE.
            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            # Prepend the generated CTEs to any existing WITH clause and mark it recursive.
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
119
120
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    table_alias = expression.alias
    if table_alias:
        # Preserve the original alias as a column alias on a fresh "_u" source.
        return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)

    return unnested
132
133
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # Map each aliased projection's name to its 1-based position in the SELECT list.
        alias_to_position = {
            projection.alias: position
            for position, projection in enumerate(expression.parent.expressions, start=1)
            if isinstance(projection, exp.Alias)
        }

        for group_expr in expression.expressions:
            is_alias_reference = (
                isinstance(group_expr, exp.Column)
                and not group_expr.table
                and group_expr.name in alias_to_position
            )
            if is_alias_reference:
                group_expr.replace(exp.Literal.number(alias_to_position[group_expr.name]))

    return expression
165
166
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        # pop() removes the DISTINCT node from the query; its ON columns become
        # the PARTITION BY of a ROW_NUMBER() window.
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        # Move any ORDER BY into the window; otherwise order by the DISTINCT ON
        # columns themselves so row numbering is deterministic.
        order = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number_window_alias)
        expression.select(window, copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects = []
        taken_names = {row_number_window_alias}
        # [:-1] skips the window projection we just appended.
        for select in expression.selects[:-1]:
            if select.is_star:
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                # Carry over quoting from the original column identifier, if any.
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        # Outer query keeps only the first row of each partition.
        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression
221
222
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so the outer query can reference it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Preserve identifier quoting when turning a projection into an outer column ref.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # pop() detaches the QUALIFY clause; its predicate becomes the outer WHERE.
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        # list(...) because the tree is mutated while we iterate its matches.
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline aliased projections referenced inside the window so the
                    # subquery's window doesn't refer to its own output aliases.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the subquery, then refer to it by alias outside.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window IS the whole QUALIFY predicate.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column referenced only in QUALIFY: project it so the outer WHERE can see it.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
285
286
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transforms removes the precision from parameterized types in expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        stripped_params = [
            param for param in data_type.expressions if not isinstance(param, exp.DataTypeParam)
        ]
        data_type.set("expressions", stripped_params)

    return expression
298
299
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Collect aliases of UNNESTs that act as table sources in this scope.
    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if unnest_aliases:
        for column in expression.find_all(exp.Column):
            if column.table in unnest_aliases:
                column.set("table", None)
            elif column.db in unnest_aliases:
                column.set("db", None)

    return expression
318
319
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapses multiple UNNEST inputs into a single ARRAYS_ZIP(...) argument,
        # mutating `u` in place; single-input UNNESTs pass through unchanged.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # Pick the table-generating function: POSEXPLODE when an offset is requested,
        # INLINE for zipped multi-array input, plain EXPLODE otherwise.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        # Case 1: UNNEST is the FROM source itself -> replace it with a UDTF table.
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Case 2: UNNEST appears in a join (possibly wrapped in LATERAL) -> drop the
        # join and emit LATERAL VIEW EXPLODE entries instead.
        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # LATERAL carries the alias on itself, a bare UNNEST on the unnest node.
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                # NOTE(review): `column` is unused in the loop body and the full
                # alias_cols list is attached to every lateral — presumably intentional
                # for the single-expr case, but worth confirming for multi-expr input.
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
411
412
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode projections into unnests.

    Args:
        index_offset: the dialect's array index base (0 or 1); it offsets the
            generated position series and the array-size bound accordingly.

    Returns:
        A transform that rewrites SELECT-level [POS]EXPLODE calls into CROSS JOIN
        UNNEST sources plus positional filtering.
    """

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Track taken projection and source names so generated aliases are fresh.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve and return a name that doesn't collide with `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One shared position series drives all exploded arrays; its "end" is
            # set after the loop, once all array sizes are known.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an exp.Alias wrapping the explode,
                    # capturing any user-provided output name(s).
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # POSEXPLODE with two aliases: (position, value).
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # EXPLODE_OUTER: empty/NULL arrays must still yield one row,
                        # so substitute a one-element array (of a safe NULL lookup).
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Project the unnested value only on the row where the shared
                    # series position matches this array's offset (NULL otherwise).
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Insert the position projection right after the value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # Attach the shared series once, on the first explode encountered.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep matching positions, or pin shorter arrays to their last
                    # element once the series has run past their size.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must run to the length of the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest
564
565
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    is_bare_percentile = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if is_bare_percentile:
        # Swap the arguments: the quantile becomes the call's operand and the
        # original column moves into the WITHIN GROUP (ORDER BY ...) clause.
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        return exp.WithinGroup(this=expression, expression=order)

    return expression
579
580
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))
593
594
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            cte_alias = cte.args["alias"]
            if cte_alias.columns:
                continue

            # For set operations, the column names come from the left-most query.
            query = cte.this
            if isinstance(query, exp.SetOperation):
                query = query.this

            columns = [
                exp.to_identifier(projection.alias_or_name or next_name())
                for projection in query.selects
            ]
            cte_alias.set("columns", columns)

    return expression
612
613
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    is_epoch_cast = (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    )
    if is_epoch_cast:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
624
625
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            condition = join.args.get("on")
            if not condition or join.kind not in ("SEMI", "ANTI"):
                continue

            # SEMI -> WHERE EXISTS(...); ANTI -> WHERE NOT EXISTS(...)
            subquery = exp.select("1").from_(join.this).where(condition)
            predicate: exp.Expression = exp.Exists(this=subquery)
            if join.kind == "ANTI":
                predicate = predicate.not_(copy=False)

            join.pop()
            expression.where(predicate, copy=False)

    return expression
641
642
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # Copy BEFORE mutating: the copy becomes the RIGHT-join half of the union.
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # Derive join conditions; a USING list is expanded into explicit equalities.
            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            # LEFT half keeps all left rows; RIGHT half is filtered with NOT EXISTS
            # against the left table so matched rows aren't emitted twice.
            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
679
680
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # The expression's own WITH clause is already top-level.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote this inner one wholesale.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # RECURSIVE is a property of the whole WITH clause, so it must be kept.
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # Find the ancestor CTE *before* detaching, so we still know where the
            # inner WITH lived.
            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Splice the moved CTEs in front of the CTE that referenced them,
                # preserving definition-before-use ordering.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
718
719
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        # Literal numbers, numeric/unknown-typed values (except subquery predicates)
        # and untyped columns are rewritten as `<value> <> 0`.
        looks_numeric = (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        )
        if looks_numeric:
            node.replace(node.neq(0))

    # Delegate condition detection to the canonicalize helper for every node.
    for candidate in expression.walk():
        ensure_bools(candidate, _ensure_bool)

    return expression
739
740
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strips table/db/catalog qualifiers from every column reference."""
    for column in expression.find_all(exp.Column):
        # Every part except the last (the column name itself) is a qualifier.
        for qualifier in column.parts[:-1]:
            qualifier.pop()

    return expression
748
749
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Drops UNIQUE column constraints from a CREATE statement."""
    assert isinstance(expression, exp.Create)

    for unique_constraint in expression.find_all(exp.UniqueColumnConstraint):
        parent = unique_constraint.parent
        if parent:
            # Remove the enclosing constraint node, not just the UNIQUE marker
            parent.pop()

    return expression
757
758
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """Rewrites CTAS over temporary tables into CREATE TEMPORARY VIEW statements."""
    assert isinstance(expression, exp.Create)

    props = expression.args.get("properties")
    prop_list = props.expressions if props else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in prop_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    # No SELECT body: delegate to the dialect-specific storage hook
    return tmp_storage_provider(expression)
781
782
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        # PARTITIONED BY names, uppercased for case-insensitive matching
        partition_names = {v.name.upper() for v in prop.this.expressions}
        partition_cols = [c for c in schema.expressions if c.name.upper() in partition_names]

        # Move the matched column defs out of the schema and into a Schema-valued property
        schema.set("expressions", [c for c in schema.expressions if c not in partition_cols])
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
        expression.set("this", schema)

    return expression
804
805
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        # Keep just the column names in PARTITIONED BY ...
        name_tuple = exp.Tuple(
            expressions=[exp.to_identifier(col_def.this) for col_def in prop.this.expressions]
        )

        # ... and move the full (typed) column definitions back into the schema
        schema = expression.this
        for col_def in prop.this.expressions:
            schema.append("expressions", col_def)

        prop.set("this", name_tuple)

    return expression
829
830
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        converted = []
        for arg in expression.expressions:
            if isinstance(arg, exp.PropertyEQ):
                # key := value  ->  value AS key
                converted.append(exp.alias_(arg.expression, arg.this))
            else:
                converted.append(arg)

        expression.set("expressions", converted)

    return expression
843
844
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Join marks only matter when there's a WHERE clause and joins to rewrite
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            # A (+) mark is only valid inside a binary predicate, e.g. a.id = b.id(+)
            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            # Detach the predicate from the WHERE clause; it becomes the join's ON condition
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the table(s) the marked columns belong to
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # Add predicate if join already copied, or add join if it is new
            # NOTE: `col` deliberately leaks out of the loop above — it's the last marked column
            join_this = old_joins.get(col.table, query_from).this
            existing_join = new_joins.get(join_this.alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args["on"], join_predicate))
            else:
                new_joins[join_this.alias_or_name] = exp.Join(
                    this=join_this.copy(), on=join_predicate.copy(), kind="LEFT"
                )

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # If the FROM table itself got turned into a LEFT JOIN, promote one of the
        # untouched joined tables into the FROM clause instead
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            query.set("joins", list(new_joins.values()))

        # Popping predicates may have emptied the WHERE clause entirely
        if not where.this:
            where.pop()

    return expression
945
946
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if isinstance(expression, exp.Select):
        # Renamed from `any` so the loop variable doesn't shadow the builtin `any`
        for any_expr in expression.find_all(exp.Any):
            this = any_expr.this
            if isinstance(this, exp.Query):
                # Subquery form of ANY is left untouched; only array arguments are rewritten
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                # Replace ANY(arr) with a lambda variable x, then wrap the whole
                # comparison into EXISTS(arr, x -> <comparison>)
                lambda_arg = exp.to_identifier("x")
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the input type so we can detect when no transform changed the node
        expression_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            # Record the error on the generator instead of aborting generation
            self.unsupported(str(unsupported_error))

        # Prefer a dedicated <key>_sql generator method for the transformed node
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                # Functions can always fall back to generic function rendering
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites `UNNEST(GENERATE_DATE_ARRAY(...))` table sources into recursive CTEs,
    for dialects that support WITH RECURSIVE but lack a date-array generator.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            # Only rewrite UNNESTs used as a table source over a single GENERATE_DATE_ARRAY
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            # The rewrite needs explicit bounds and an INTERVAL step
            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # Numeric suffix keeps CTE names unique when several arrays are rewritten
            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            # Anchor member selects the start date; the recursive member adds the
            # step to the previous row, bounded (inclusively) by the end date
            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            # Prepend the generated CTEs to any existing WITH clause and mark it recursive
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    series = expression.this
    if isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[series])
        if not expression.alias:
            return unnest

        # Preserve the original table alias as a column alias on the unnest
        return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

    return expression

Unnests GENERATE_SERIES or SEQUENCE table references.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # Map each aliased projection's name to its 1-based position in the SELECT list
        alias_to_position = {
            proj.alias: idx
            for idx, proj in enumerate(expression.parent.expressions, start=1)
            if isinstance(proj, exp.Alias)
        }

        for item in expression.expressions:
            is_bare_alias_ref = (
                isinstance(item, exp.Column)
                and not item.table
                and item.name in alias_to_position
            )
            if is_bare_alias_ref:
                # Swap the alias reference for its ordinal position
                item.replace(exp.Literal.number(alias_to_position.get(item.name)))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        # ROW_NUMBER() partitioned by the DISTINCT ON columns; rank 1 marks the row to keep
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        # Reuse the query's ORDER BY for ranking, otherwise order by the DISTINCT ON columns
        order = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number_window_alias)
        expression.select(window, copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects = []
        taken_names = {row_number_window_alias}
        for select in expression.selects[:-1]:
            if select.is_star:
                # A star projection subsumes all other projections in the outer query
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        # The outer query keeps the original projections and filters on row number 1
        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Ensure every projection has a name so the outer query can reference it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Preserve identifier quoting when producing the outer projection reference
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # For star selects only windows need hoisting; otherwise bare columns do too
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                # Inline aliased projections referenced inside the window to avoid
                # invalid references to aliases defined in the same SELECT
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh alias and filter on that alias outside
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window IS the whole filter; replace the filter with the alias
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column used by QUALIFY but not selected: project it so it's reachable
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        # Strip DataTypeParam children, e.g. DECIMAL(10, 2) -> DECIMAL
        stripped = [e for e in data_type.expressions if not isinstance(e, exp.DataTypeParam)]
        data_type.set("expressions", stripped)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of UNNESTs appearing as table sources in this scope
    unnest_aliases = {
        unnest.alias
        for unnest in find_all_in_scope(expression, exp.Unnest)
        if isinstance(unnest.parent, (exp.From, exp.Join))
    }

    if unnest_aliases:
        for column in expression.find_all(exp.Column):
            if column.table in unnest_aliases:
                column.set("table", None)
            elif column.db in unnest_aliases:
                column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.Expression:
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapse multiple unnest arguments into a single ARRAYS_ZIP call, in place
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # Pick the table-generating function: POSEXPLODE when an offset is
        # requested, INLINE for zipped multi-arrays, EXPLODE otherwise
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        # Case 1: UNNEST directly in the FROM clause becomes a UDTF table source
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Case 2: UNNEST in a join (possibly wrapped in LATERAL) becomes a LATERAL VIEW
        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # The alias lives on the LATERAL wrapper if there is one
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                # One LATERAL VIEW per unnested expression, paired with its column alias
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_projection_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode projections into unnests.

    Returns a transform that rewrites `SELECT ... EXPLODE(arr) ...` projections into
    CROSS JOIN UNNEST sources, pairing each exploded value with a generated position
    series so that multiple explosions in one SELECT stay aligned row-by-row.

    Args:
        index_offset: the dialect's array index base (e.g. 0 for Trino/Presto,
            1 for one-based dialects); used to start the position series and to
            adjust the array-size bounds.
    """

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Collect names already in use so generated aliases don't collide
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Find a fresh name, then reserve it for subsequent calls
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Shared position series: UNNEST(GENERATE_SERIES(index_offset, <end set later>))
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an exp.Alias so we can retarget it
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # POSEXPLODE with two aliases: (pos, value)
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # EXPLODE_OUTER must yield one NULL row for empty/NULL arrays:
                        # substitute a single-element array built from a safe bracket access
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Project the value only on the row where the series position
                    # matches this unnest's offset (NULL otherwise)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, right after the value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # Attach the shared series source once, on the first explode found
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep the row where positions line up, or pin shorter arrays to
                    # their last element so a longer sibling array can keep emitting rows
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must run to the longest array's size (offset-adjusted)
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest

Convert explode/posexplode projections into unnests.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Wraps percentile functions in a WITHIN GROUP (ORDER BY ...) clause."""
    is_bare_percentile = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and bool(expression.expression)
    )
    if not is_bare_percentile:
        return expression

    # Swap arguments: the ordering column moves into the ORDER BY, while the
    # quantile value becomes the percentile function's sole argument.
    column = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=column)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Rewrites `PERCENTILE(...) WITHIN GROUP (ORDER BY col)` as an APPROX_QUANTILE call."""
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile = expression.this.this
    # The ordered expression inside the ORDER BY is the value being aggregated
    input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
    return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        alias = cte.args["alias"]
        if alias.columns:
            # Columns were declared explicitly; leave them alone
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            # For UNION-style bodies, the first branch determines the output names
            query = query.this

        alias.set(
            "columns",
            [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
        )

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    is_epoch_cast = (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    )
    if is_epoch_cast:
        # 'epoch' is shorthand for the Unix epoch timestamp
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        on = join.args.get("on")
        if not on or join.kind not in ("SEMI", "ANTI"):
            continue

        # SEMI join -> WHERE EXISTS(...); ANTI join -> WHERE NOT EXISTS(...)
        predicate: exp.Condition = exp.Exists(
            this=exp.select("1").from_(join.this).where(on)
        )
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # Copy first: the copy becomes the RIGHT JOIN half of the union
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            # With USING (...), expand into explicit equality predicates between
            # the FROM table and the joined table
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            # The RIGHT half anti-joins against the LEFT table so rows already
            # produced by the LEFT JOIN half are not duplicated in the union
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: ~E) -> ~E:
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            # Already attached at the top level; nothing to move
            continue

        if not top_level_with:
            # No top-level WITH yet: promote the first nested one we encounter
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                # RECURSIVE is sticky: any recursive nested WITH makes the merged one recursive
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # The WITH was nested inside another CTE: splice its CTEs in *before*
                # that parent CTE so that every definition precedes its first use
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    # Aliased on import to avoid shadowing this function's own name
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _to_boolean(node: exp.Expression) -> None:
        # A node is rewritten to `node <> 0` when it looks numeric: a literal
        # number, a value typed numeric/unknown (subquery predicates are already
        # boolean and are excluded), or an untyped column.
        looks_numeric = (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        )
        if looks_numeric:
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _to_boolean)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strips table/db/catalog qualifiers from every column reference."""
    for column in expression.find_all(exp.Column):
        # The last part is the column name itself; everything before it is a qualifier
        for qualifier in column.parts[:-1]:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Drops UNIQUE column constraints from a CREATE statement."""
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        parent = constraint.parent
        if parent:
            # Remove the whole enclosing constraint node, not just the UNIQUE marker
            parent.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = lambda e: e) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """Maps CREATE TEMPORARY TABLE statements onto CREATE TEMPORARY VIEW equivalents.

    Temp tables without a CTAS query are handed to `tmp_storage_provider` so the
    caller can decide how their storage should be expressed.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    prop_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in prop_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    # Plain temporary table (no SELECT body): delegate storage handling
    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        # Nothing to do when there is no property, or it is already a schema
        return expression

    schema = expression.this
    partition_names = {v.name.upper() for v in prop.this.expressions}
    # Full column definitions for the partition columns, looked up case-insensitively
    partition_cols = [col for col in schema.expressions if col.name.upper() in partition_names]
    schema.set("expressions", [e for e in schema.expressions if e not in partition_cols])
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)

    can_move = (
        prop is not None
        and bool(prop.this)
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    )
    if not can_move:
        return expression

    # PARTITIONED BY keeps just the column names; the typed column definitions
    # are moved back into the table's schema.
    column_defs = prop.this.expressions
    name_tuple = exp.Tuple(expressions=[exp.to_identifier(e.this) for e in column_defs])
    schema = expression.this
    for column_def in column_defs:
        schema.append("expressions", column_def)
    prop.set("this", name_tuple)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        converted = []
        for arg in expression.expressions:
            if isinstance(arg, exp.PropertyEQ):
                # `y := 1` becomes `1 AS y`
                converted.append(exp.alias_(arg.expression, arg.this))
            else:
                converted.append(arg)

        expression.set("expressions", converted)

    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Join marks only appear in WHERE predicates of queries that have joins
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Detach the whole predicate from WHERE; it becomes a join condition
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and record which (single) table they belong to
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # Add predicate if join already copied, or add join if it is new
            join_this = old_joins.get(col.table, query_from).this
            existing_join = new_joins.get(join_this.alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args["on"], join_predicate))
            else:
                new_joins[join_this.alias_or_name] = exp.Join(
                    this=join_this.copy(), on=join_predicate.copy(), kind="LEFT"
                )

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            # The FROM table itself became a LEFT JOIN target; promote one of the
            # remaining (unconverted) join tables into the FROM clause instead
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            query.set("joins", list(new_joins.values()))

        if not where.this:
            # Every predicate was consumed as a join condition; drop the empty WHERE
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example, `SELECT * FROM a, b WHERE a.id = b.id(+)` is converted to `SELECT * FROM a LEFT JOIN b ON a.id = b.id`.

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.

def any_to_exists( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if isinstance(expression, exp.Select):
        # Loop variable renamed from `any` so the `any` builtin isn't shadowed
        for any_expr in expression.find_all(exp.Any):
            this = any_expr.this
            if isinstance(this, exp.Query):
                # ANY(<subquery>) has no EXISTS(array, lambda) equivalent; skip it
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                # Swap the ANY node for the lambda-bound variable, then wrap the
                # resulting comparison: `5 > ANY(col)` -> `EXISTS(col, x -> 5 > x)`
                lambda_arg = exp.to_identifier("x")
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression

Transform ANY operator to Spark's EXISTS

For example, - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation