Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.errors import UnsupportedError
  7from sqlglot.helper import find_new_name, name_sequence
  8
  9
 10if t.TYPE_CHECKING:
 11    from sqlglot._typing import E
 12    from sqlglot.generator import Generator
 13
 14
 15def preprocess(
 16    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 17) -> t.Callable[[Generator, exp.Expression], str]:
 18    """
 19    Creates a new transform by chaining a sequence of transformations and converts the resulting
 20    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
 21    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
 22
 23    Args:
 24        transforms: sequence of transform functions. These will be called in order.
 25
 26    Returns:
 27        Function that can be used as a generator transform.
 28    """
 29
 30    def _to_sql(self, expression: exp.Expression) -> str:
 31        expression_type = type(expression)
 32
 33        try:
 34            expression = transforms[0](expression)
 35            for transform in transforms[1:]:
 36                expression = transform(expression)
 37        except UnsupportedError as unsupported_error:
 38            self.unsupported(str(unsupported_error))
 39
 40        _sql_handler = getattr(self, expression.key + "_sql", None)
 41        if _sql_handler:
 42            return _sql_handler(expression)
 43
 44        transforms_handler = self.TRANSFORMS.get(type(expression))
 45        if transforms_handler:
 46            if expression_type is type(expression):
 47                if isinstance(expression, exp.Func):
 48                    return self.function_fallback_sql(expression)
 49
 50                # Ensures we don't enter an infinite loop. This can happen when the original expression
 51                # has the same type as the final expression and there's no _sql method available for it,
 52                # because then it'd re-enter _to_sql.
 53                raise ValueError(
 54                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
 55                )
 56
 57            return transforms_handler(self, expression)
 58
 59        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 60
 61    return _to_sql
 62
 63
 64def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 65    if isinstance(expression, exp.Select):
 66        count = 0
 67        recursive_ctes = []
 68
 69        for unnest in expression.find_all(exp.Unnest):
 70            if (
 71                not isinstance(unnest.parent, (exp.From, exp.Join))
 72                or len(unnest.expressions) != 1
 73                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 74            ):
 75                continue
 76
 77            generate_date_array = unnest.expressions[0]
 78            start = generate_date_array.args.get("start")
 79            end = generate_date_array.args.get("end")
 80            step = generate_date_array.args.get("step")
 81
 82            if not start or not end or not isinstance(step, exp.Interval):
 83                continue
 84
 85            alias = unnest.args.get("alias")
 86            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 87
 88            start = exp.cast(start, "date")
 89            date_add = exp.func(
 90                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 91            )
 92            cast_date_add = exp.cast(date_add, "date")
 93
 94            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 95
 96            base_query = exp.select(start.as_(column_name))
 97            recursive_query = (
 98                exp.select(cast_date_add)
 99                .from_(cte_name)
100                .where(cast_date_add <= exp.cast(end, "date"))
101            )
102            cte_query = base_query.union(recursive_query, distinct=False)
103
104            generate_dates_query = exp.select(column_name).from_(cte_name)
105            unnest.replace(generate_dates_query.subquery(cte_name))
106
107            recursive_ctes.append(
108                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
109            )
110            count += 1
111
112        if recursive_ctes:
113            with_expression = expression.args.get("with") or exp.With()
114            with_expression.set("recursive", True)
115            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
116            expression.set("with", with_expression)
117
118    return expression
119
120
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])

    table_alias = expression.alias
    if not table_alias:
        return unnested

    # The table's alias becomes the column alias of the unnested source, which is named "_u"
    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)
132
133
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Unqualified GROUP BY columns whose name matches a projection alias are replaced by the
    1-based position of that projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)):
        return expression

    # Map every aliased projection to its ordinal position in the SELECT list
    position_by_alias = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for group_key in expression.expressions:
        if not isinstance(group_key, exp.Column) or group_key.table:
            continue

        if group_key.name in position_by_alias:
            group_key.replace(exp.Literal.number(position_by_alias[group_key.name]))

    return expression
165
166
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The original query gets an extra ROW_NUMBER() window projection partitioned by the DISTINCT ON
    columns, is wrapped in a subquery aliased "_t", and the new outer query keeps only the rows
    whose row number equals 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        # Choose a name for the row number projection that doesn't clash with existing projections
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        # Detach the DISTINCT node from the query; its ON columns become the partition keys
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        # Move the query's ORDER BY into the window if there is one, otherwise order the window
        # by copies of the DISTINCT ON columns
        order = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number_window_alias)
        expression.select(window, copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects = []
        taken_names = {row_number_window_alias}
        # The last projection is the window that was just appended, so it's skipped here
        for select in expression.selects[:-1]:
            if select.is_star:
                # A star projection makes the outer query select everything
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                # Carry over the "quoted" flag of a column's identifier to the new alias
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression
221
222
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so that the outer query can reference it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Build the outer reference to a projection; when the alias (or the projection's
            # `this`) is an identifier, its "quoted" flag is carried over to the new column
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are computed before any helper projection is added below
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach the QUALIFY clause and keep its condition
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection only windows are looked for, otherwise windows and columns
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Replace references to projection aliases by the aliased expressions
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the inner query and refer to it through its new alias
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the whole QUALIFY condition, so the filter is just the column
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A column used by the filter but not projected is added to the inner query
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
285
286
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        # Keep every type argument except the DataTypeParam nodes
        remaining_args = [
            arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)
        ]
        data_type.set("expressions", remaining_args)

    return expression
298
299
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Collect the aliases of the UNNESTs that act as FROM / JOIN sources of this SELECT
    source_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            source_aliases.add(unnest.alias)

    if not source_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in source_aliases:
            column.set("table", None)
        elif column.db in source_aliases:
            column.set("db", None)

    return expression
318
319
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Args:
        expression: the expression to transform; only SELECT nodes are rewritten.
        unnest_using_arrays_zip: whether an UNNEST with several input arrays may be rewritten
            using ARRAYS_ZIP. When False, such an UNNEST raises an UnsupportedError.

    Returns:
        The transformed expression.

    Raises:
        UnsupportedError: for a multi-array UNNEST when `unnest_using_arrays_zip` is False, or
            for a joined single-array UNNEST that doesn't have one or two column aliases.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapse several unnest arguments into a single ARRAYS_ZIP(...) call (also stored back
        # on the UNNEST node); a single argument is returned untouched
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # An UNNEST with an offset maps to POSEXPLODE, a zipped one to INLINE, the rest to EXPLODE
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        # Case 1: the UNNEST is the FROM source itself, so it's replaced by a table function
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Case 2: joined UNNESTs (optionally wrapped in LATERAL) become LATERAL VIEWs.
        # The joins list is iterated through a copy because joins are removed inside the loop.
        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For LATERAL the alias lives on the Lateral node, otherwise on the UNNEST
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: an exception is raised when there are 0 or > 2 aliases.
                # Spark LATERAL VIEW EXPLODE requires a single alias for array/struct and two for Map type columns,
                # unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                # One lateral view is added per (expression, alias column) pair produced by zip;
                # the alias column itself is unused because every lateral receives all alias columns
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
411
412
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the first position value used for the generated series; it also controls
            how array sizes are adjusted when they are compared against positions.

    Returns:
        A transform function that rewrites the EXPLODE / POSEXPLODE projections of a SELECT into
        cross joins against UNNEST sources, aligned through a generated position series.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names that are already in use for projections and for sources, respectively
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Find a free name based on `name` and reserve it right away
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Size expressions of all exploded arrays; they bound the position series at the end
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(index_offset, ...)): the series' end is filled in at the
            # bottom of this function, once every exploded array is known
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Make sure the projection is wrapped in an Alias node (`alias`), keeping any
                    # user-provided names for the exploded value and its position
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # The projection is re-searched in the new Alias node
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg), where
                        # the index expression is flagged as "safe" and "offset"
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Fall back to generated names where the query didn't provide any
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The exploded value is only produced on rows where the series position
                    # matches the position of the unnested element
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position, projected right after the value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The position series is added as a source once, for the first explode found
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Unless positions start at 1, the last position is size - 1
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, or where the series has run past the
                    # end of this array and the unnested row is the array's last element
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series runs up to the size of the largest exploded array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
562
563
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    # The first argument becomes the ORDER BY column, the second one the percentile's only argument
    ordered_column = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_column)])
    return exp.WithinGroup(this=expression, expression=ordering)
577
578
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile = expression.this.this
    # The value being aggregated is the operand of the first Ordered node that is found
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx_quantile)
591
592
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    # Generates fallback names ("_c_" prefix) for projections that have no name of their own
    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        # For set operations the column names are taken from the left-hand side query
        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        table_alias.set("columns", column_names)

    return expression
610
611
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        # Only the cast's operand is swapped, the cast itself is preserved
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
622
623
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.

    Every SEMI / ANTI join that has an ON condition is removed from the SELECT and replaced by an
    `EXISTS (SELECT 1 FROM <joined source> WHERE <on condition>)` predicate (negated for ANTI
    joins) that is added to the SELECT's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: join.pop() below detaches the join from the
        # SELECT, and removing items from a list that is being iterated would make the loop
        # skip the join that follows a removed one (e.g. two consecutive SEMI joins).
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
639
640
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The result is `<query with LEFT join> UNION ALL <query with RIGHT join>`, where the second
    query is additionally filtered by a NOT EXISTS predicate built from the join condition.
    """
    if isinstance(expression, exp.Select):
        # Remember each FULL join together with its position in the joins list
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # The copy becomes the RIGHT join branch. It's taken before the original is modified,
            # so it still carries the limit that is removed from the LEFT branch right below.
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # Use the ON condition if present, otherwise derive equalities from the USING columns
            # NOTE(review): a FULL join with neither ON nor USING would make the comprehension
            # below iterate over None -- confirm whether that input can reach this transform
            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            # NOT EXISTS (SELECT 1 FROM <left source> WHERE <join condition>) is the filter that
            # gets added to the RIGHT branch
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
677
678
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # The top-level WITH clause is already where it should be
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # There's no top-level WITH yet, so the first nested one is promoted as a whole
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # The merged clause must be recursive if any of the moved ones is
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # The enclosing CTE (if any) has to be looked up before the WITH is detached
            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # CTEs nested inside another CTE are inserted right before that CTE
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                # Other nested CTEs are appended after the existing top-level ones
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
716
717
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _needs_explicit_comparison(node: exp.Expression) -> bool:
        # Number literals, nodes of unknown / numeric type (subquery predicates excluded) and
        # untyped columns are the nodes that get compared against zero
        if node.is_number:
            return True

        if not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            return True

        return isinstance(node, exp.Column) and not node.type

    def _compare_to_zero(node: exp.Expression) -> None:
        if _needs_explicit_comparison(node):
            node.replace(node.neq(0))

    for descendant in expression.walk():
        _canonicalize_ensure_bools(descendant, _compare_to_zero)

    return expression
737
738
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strip the qualifiers from every column in the expression, keeping only the last part."""
    for column in expression.find_all(exp.Column):
        # Everything but the last part (i.e. the table, db and catalog args) is popped off
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
746
747
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove unique column constraints from a CREATE statement.

    For each `UniqueColumnConstraint` found, the node that contains it (its parent) is popped
    from the tree.

    Args:
        expression: the `exp.Create` expression that will be transformed.

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        wrapper = unique.parent
        if wrapper:
            wrapper.pop()

    return expression
755
756
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map a CREATE TABLE that carries a `TemporaryProperty` to a CREATE TEMPORARY VIEW.

    Args:
        expression: the `exp.Create` expression that will be transformed.
        tmp_storage_provider: called with the statement when it is a temporary table that has
            no source query; its result is returned. Defaults to the identity function.

    Returns:
        A new `exp.Create` of kind "TEMPORARY VIEW" for a temporary CTAS, the provider's result
        for a temporary table without a query, or the original expression otherwise.
    """
    assert isinstance(expression, exp.Create)

    props = expression.args.get("properties")
    prop_nodes = props.expressions if props else []
    is_temporary = any(isinstance(node, exp.TemporaryProperty) for node in prop_nodes)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    query = expression.expression
    if not query:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(kind="TEMPORARY VIEW", this=expression.this, expression=query)
779
780
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Only TABLE and VIEW statements whose target is a schema are touched, and only when the
    PARTITIONED BY value is not already a schema. Column names are matched case-insensitively.
    """
    assert isinstance(expression, exp.Create)

    target = expression.this
    if not isinstance(target, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not (partitioned_by and partitioned_by.this) or isinstance(
        partitioned_by.this, exp.Schema
    ):
        return expression

    partition_names = {name.name.upper() for name in partitioned_by.this.expressions}

    partition_defs = []
    for column_def in target.expressions:
        if column_def.name.upper() in partition_names:
            partition_defs.append(column_def)

    remaining_defs = []
    for column_def in target.expressions:
        if column_def not in partition_defs:
            remaining_defs.append(column_def)

    target.set("expressions", remaining_defs)
    partitioned_by.replace(
        exp.PartitionedByProperty(this=exp.Schema(expressions=partition_defs))
    )
    expression.set("this", target)

    return expression
802
803
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY value is a schema made up solely of typed column definitions, those
    definitions are appended to the statement's schema and the property is left holding a tuple
    of the column identifiers.
    """
    assert isinstance(expression, exp.Create)

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not (
        partitioned_by
        and partitioned_by.this
        and isinstance(partitioned_by.this, exp.Schema)
    ):
        return expression

    partition_defs = partitioned_by.this.expressions
    for column_def in partition_defs:
        if not (isinstance(column_def, exp.ColumnDef) and column_def.kind):
            return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(column_def.this) for column_def in partition_defs]
    )

    target_schema = expression.this
    for column_def in partition_defs:
        target_schema.append("expressions", column_def)

    partitioned_by.set("this", partition_names)
    return expression
827
828
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each `PropertyEQ` argument of a struct becomes an alias of its value named after its key;
    other arguments and non-struct expressions are left as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    aliased_args = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            arg = exp.alias_(arg.expression, arg.this)
        aliased_args.append(arg)

    expression.set("expressions", aliased_args)
    return expression
841
842
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Each scope that has both a WHERE clause and joins is processed independently: predicates
    containing a marked column are moved out of the WHERE clause into the ON condition of a
    LEFT join on the marked column's table.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    def _marked_columns(operand: exp.Expression) -> t.List[exp.Column]:
        # Columns under `operand` that still carry the (+) marker.
        return [c for c in operand.find_all(exp.Column) if c.args.get("join_mark")]

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not (where and joins):
            continue

        query_from = query.args["from"]

        # Joins as originally written, and the LEFT joins that will replace them
        joins_by_name = {join.alias_or_name: join for join in joins}
        left_joins: t.Dict[str, exp.Join] = {}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_marked = _marked_columns(join_predicate.left)
            right_marked = _marked_columns(join_predicate.right)

            assert not (
                left_marked and right_marked
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_tables = set()
            for col in left_marked or right_marked:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_tables.add(table)

            assert (
                len(marked_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column visited by the loop above
            join_source = joins_by_name.get(col.table, query_from)
            candidate = exp.Join(this=join_source.this, on=join_predicate, kind="LEFT")

            # Merge into an already-created join on the same table, or register a new one
            candidate_name = candidate.alias_or_name
            registered = left_joins.get(candidate_name)
            if registered:
                registered.set(
                    "on", exp.and_(registered.args.get("on"), candidate.args["on"])
                )
            else:
                left_joins[candidate_name] = candidate

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                survivor = (
                    predicate_parent.right
                    if predicate_parent.left is None
                    else predicate_parent.left
                )
                predicate_parent.replace(survivor)

        if query_from.alias_or_name in left_joins:
            untouched_names = joins_by_name.keys() - left_joins.keys()
            assert (
                len(untouched_names) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(untouched_names)[0]
            query.set("from", exp.From(this=joins_by_name[new_from_name].this))

        query.set("joins", list(left_joins.values()))

        # Drop the WHERE clause if every predicate was moved into a join
        if not where.this:
            where.pop()

    return expression
943
944
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation

    Args:
        expression: the expression that will be transformed; only `exp.Select` nodes are touched.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    # The loop variable is deliberately not called `any`, which would shadow the builtin.
    for any_expr in expression.find_all(exp.Any):
        operand = any_expr.this
        if isinstance(operand, exp.Query):
            # ANY over a query is not supported by this transformation
            continue

        binop = any_expr.parent
        if not isinstance(binop, exp.Binary):
            continue

        # Swap ANY(...) for the lambda parameter, so that a copy of the comparison
        # becomes the lambda body
        lambda_arg = exp.to_identifier("x")
        any_expr.replace(lambda_arg)
        lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
        binop.replace(exp.Exists(this=operand.unnest(), expression=lambda_expr))

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated unchanged.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the final expression has neither a
            "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            for transform in transforms:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            # Report the problem and carry on with whatever was produced so far
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 65def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 66    if isinstance(expression, exp.Select):
 67        count = 0
 68        recursive_ctes = []
 69
 70        for unnest in expression.find_all(exp.Unnest):
 71            if (
 72                not isinstance(unnest.parent, (exp.From, exp.Join))
 73                or len(unnest.expressions) != 1
 74                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 75            ):
 76                continue
 77
 78            generate_date_array = unnest.expressions[0]
 79            start = generate_date_array.args.get("start")
 80            end = generate_date_array.args.get("end")
 81            step = generate_date_array.args.get("step")
 82
 83            if not start or not end or not isinstance(step, exp.Interval):
 84                continue
 85
 86            alias = unnest.args.get("alias")
 87            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 88
 89            start = exp.cast(start, "date")
 90            date_add = exp.func(
 91                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 92            )
 93            cast_date_add = exp.cast(date_add, "date")
 94
 95            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 96
 97            base_query = exp.select(start.as_(column_name))
 98            recursive_query = (
 99                exp.select(cast_date_add)
100                .from_(cte_name)
101                .where(cast_date_add <= exp.cast(end, "date"))
102            )
103            cte_query = base_query.union(recursive_query, distinct=False)
104
105            generate_dates_query = exp.select(column_name).from_(cte_name)
106            unnest.replace(generate_dates_query.subquery(cte_name))
107
108            recursive_ctes.append(
109                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
110            )
111            count += 1
112
113        if recursive_ctes:
114            with_expression = expression.args.get("with") or exp.With()
115            with_expression.set("recursive", True)
116            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
117            expression.set("with", with_expression)
118
119    return expression
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A table whose `this` is a `GenerateSeries` is wrapped in an `Unnest`. If the table was
    aliased, the unnest is aliased as `_u` with the original alias as its single column name.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    if not expression.alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[expression.alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Unqualified GROUP BY columns whose name matches a projection alias are replaced by the
    1-based position of that projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Alias name -> ordinal position of the aliased projection
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for term in expression.expressions:
        if not isinstance(term, exp.Column) or term.table:
            continue
        if term.name in position_by_alias:
            term.replace(exp.Literal.number(position_by_alias.get(term.name)))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The DISTINCT ON columns become the partition of a ROW_NUMBER() window that is added as an
    extra projection; the query's ORDER BY (or, if absent, the DISTINCT ON columns) orders the
    window. The original query is wrapped in a subquery `_t` filtered on the row number being 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct or not isinstance(distinct.args.get("on"), exp.Tuple):
        return expression

    row_number_alias = find_new_name(expression.named_selects, "_row_number")

    partition_cols = distinct.pop().args["on"].expressions
    row_number = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    existing_order = expression.args.get("order")
    if existing_order:
        window_order = existing_order.pop()
    else:
        window_order = exp.Order(expressions=[col.copy() for col in partition_cols])
    row_number.set("order", window_order)

    expression.select(exp.alias_(row_number, row_number_alias), copy=False)

    # We add aliases to the projections so that we can safely reference them in the outer query
    outer_projections = []
    used_names = {row_number_alias}
    # The last projection is the row number window that was just added
    for projection in expression.selects[:-1]:
        if projection.is_star:
            outer_projections = [exp.Star()]
            break

        if not isinstance(projection, exp.Alias):
            new_alias = find_new_name(used_names, projection.output_name or "_col")
            if isinstance(projection, exp.Column):
                quoted = projection.this.args.get("quoted")
            else:
                quoted = None
            projection = projection.replace(exp.alias_(projection, new_alias, quoted=quoted))

        used_names.add(projection.output_name)
        outer_projections.append(projection.args["alias"])

    outer_query = exp.select(*outer_projections, copy=False)
    outer_query = outer_query.from_(expression.subquery("_t", copy=False), copy=False)
    return outer_query.where(exp.column(row_number_alias).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if not (isinstance(expression, exp.Select) and expression.args.get("qualify")):
        return expression

    # Give every unnamed projection an alias so the outer query can refer to it
    used_names = set(expression.named_selects)
    for projection in expression.selects:
        if projection.alias_or_name:
            continue

        generated = find_new_name(used_names, "_c")
        projection.replace(exp.alias_(projection, generated))
        used_names.add(generated)

    def _outer_reference(projection: exp.Expression) -> str | exp.Column:
        # Carry the identifier's quoting over to the outer column when there is one
        name = projection.alias_or_name
        identifier = projection.args.get("alias") or projection.this
        if not isinstance(identifier, exp.Identifier):
            return name
        return exp.column(name, quoted=identifier.args.get("quoted"))

    outer_selects = exp.select(*[_outer_reference(p) for p in expression.selects])
    qualify_filters = expression.args["qualify"].pop().this

    aliased_expressions = {}
    for projection in expression.selects:
        if isinstance(projection, exp.Alias):
            aliased_expressions[projection.alias] = projection.this

    # With a star projection every column is already selected, so only windows need adding
    candidate_types = exp.Window if expression.is_star else (exp.Window, exp.Column)
    for candidate in qualify_filters.find_all(candidate_types):
        if isinstance(candidate, exp.Window):
            if aliased_expressions:
                # Alias references inside the window are swapped for the aliased expressions
                for column in candidate.find_all(exp.Column):
                    aliased = aliased_expressions.get(column.name)
                    if aliased:
                        column.replace(aliased)

            window_alias = find_new_name(expression.named_selects, "_w")
            expression.select(exp.alias_(candidate, window_alias), copy=False)
            window_ref = exp.column(window_alias)

            if isinstance(candidate.parent, exp.Qualify):
                # The window was the entire QUALIFY condition
                qualify_filters = window_ref
            else:
                candidate.replace(window_ref)
        elif candidate.name not in expression.named_selects:
            expression.select(candidate.copy(), copy=False)

    inner = expression.subquery(alias="_t", copy=False)
    return outer_selects.from_(inner, copy=False).where(qualify_filters, copy=False)

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Concretely, every `DataTypeParam` is dropped from the arguments of each `DataType` found.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = []
        for arg in data_type.expressions:
            if not isinstance(arg, exp.DataTypeParam):
                kept_args.append(arg)

        data_type.set("expressions", kept_args)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    The aliases are collected from the unnests in the SELECT's own scope that sit directly in a
    FROM or JOIN. A column whose table equals one of them loses its table; failing that, a column
    whose db equals one of them loses its db.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.Expression:
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause is replaced by a table wrapping the matching UDTF, and joined
    UNNESTs (optionally wrapped in a LATERAL) are removed from the joins and re-added as
    lateral views.

    Args:
        expression: the expression that will be transformed; only `exp.Select` nodes are touched.
        unnest_using_arrays_zip: whether an UNNEST with several input arrays may be expressed
            through ARRAYS_ZIP. When False, such an UNNEST raises `UnsupportedError`.

    Returns:
        The transformed expression.

    Raises:
        UnsupportedError: for a multi-array UNNEST when `unnest_using_arrays_zip` is False, or
            for a joined single-array UNNEST that doesn't have one or two column aliases.
    """

    def _zip_if_multiple(
        unnest_node: exp.Unnest, arrays: t.List[exp.Expression], is_multi: bool
    ) -> t.List[exp.Expression]:
        if not is_multi:
            return arrays

        if not unnest_using_arrays_zip:
            raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

        # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
        zipped: t.List[exp.Expression] = [
            exp.Anonymous(this="ARRAYS_ZIP", expressions=arrays)
        ]
        unnest_node.set("expressions", zipped)
        return zipped

    def _explode_class(unnest_node: exp.Unnest, is_multi: bool) -> t.Type[exp.Func]:
        if unnest_node.args.get("offset"):
            return exp.Posexplode
        if is_multi:
            return exp.Inline
        return exp.Explode

    if not isinstance(expression, exp.Select):
        return expression

    from_ = expression.args.get("from")
    if from_ and isinstance(from_.this, exp.Unnest):
        unnest = from_.this
        alias = unnest.args.get("alias")
        arrays = unnest.expressions
        is_multi = len(arrays) > 1
        first, *rest = _zip_if_multiple(unnest, arrays, is_multi)

        udtf = _explode_class(unnest, is_multi)(this=first, expressions=rest)
        if alias:
            table_alias = exp.TableAlias(this=alias.this, columns=alias.columns)
        else:
            table_alias = None

        unnest.replace(exp.Table(this=udtf, alias=table_alias))

    joins = expression.args.get("joins") or []
    # Iterate over a snapshot, since matching joins are removed from the list as we go
    for join in list(joins):
        joined = join.this
        is_lateral = isinstance(joined, exp.Lateral)
        unnest = joined.this if is_lateral else joined

        if not isinstance(unnest, exp.Unnest):
            continue

        # For LATERAL the alias lives on the lateral node rather than on the unnest
        alias_owner = joined if is_lateral else unnest
        alias = alias_owner.args.get("alias")

        arrays = unnest.expressions
        # _zip_if_multiple changes the number of unnest.expressions, so record this beforehand
        is_multi = len(arrays) > 1
        arrays = _zip_if_multiple(unnest, arrays, is_multi)

        joins.remove(join)

        alias_cols = alias.columns if alias else []

        # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
        # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
        # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html
        if not is_multi and len(alias_cols) not in (1, 2):
            raise UnsupportedError(
                "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
            )

        # Pairing with alias_cols only bounds how many lateral views are emitted
        for array, _ in zip(arrays, alias_cols):
            lateral_view = exp.Lateral(
                this=_explode_class(unnest, is_multi)(this=array),
                view=True,
                alias=exp.TableAlias(
                    this=alias.this,  # type: ignore
                    columns=alias_cols,
                ),
            )
            expression.append("laterals", lateral_view)

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: number used as the start of the generated position series. The array-size
            bound used in the WHERE condition and the end of the series are adjusted relative
            to this value (see the `index_offset != 1` branches below).

    Returns:
        A transform function. It rewrites `exp.Select` expressions in place and returns them;
        any other expression is returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names that are already in use, so that generated aliases don't collide with them
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a name via find_new_name and records it in `names` as taken
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per rewritten explode; used to compute the series end
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(index_offset, <end>)) AS _u(pos); the series end is only
            # filled in after the loop, once all exploded arrays are known
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an Alias node (`alias`), remembering any
                    # user-provided aliases for the exploded value and its position
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # exp.alias_ is called without copy=False here, so the explode node
                        # is looked up again inside the new alias
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Builds IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg)
                        # NOTE(review): presumably this makes an empty/NULL array still produce
                        # a single row, as the OUTER variant requires -- confirm
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate aliases for whatever the user did not name explicitly
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call becomes IF(<series pos> = <unnest pos>, <unnest value>)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position, so an extra projection is
                        # inserted right after the rewritten one
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The position series is attached only once, when the first explode is seen
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` was already appended to `arrays` above; this rebinding only
                    # affects the WHERE condition built below
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position equals the unnest position, or where
                    # the series position is past `size` and the unnest position equals `size`
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must be long enough for the largest exploded array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wraps two-argument percentile calls in a WITHIN GROUP (ORDER BY ...) clause.

    The percentile's first argument becomes the ORDER BY key and its second argument
    becomes the sole argument of the percentile function. Percentiles that already sit
    under a WITHIN GROUP, or that have no second argument, are returned unchanged.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    # Detach both arguments before re-attaching them in their new positions
    ordered_column = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_column)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replaces `<percentile>(q) WITHIN GROUP (ORDER BY x)` with an `exp.ApproxQuantile` node.

    The ordered expression becomes the quantile input and the percentile's argument becomes
    the quantile. Anything that does not match this shape is returned unchanged.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(replacement)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Fills in the column list of each CTE in a recursive WITH clause that lacks one.

    Column names are taken from the projections of the CTE's query (the left-most query when
    the CTE is a set operation); projections without an output name get a generated `_c_` name.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        body = cte.this
        if isinstance(body, exp.SetOperation):
            body = body.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or fresh_name())
            for projection in body.selects
        ]
        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Swaps the 'epoch' operand of a temporal cast for the literal '1970-01-01 00:00:00'.

    Only `exp.Cast` / `exp.TryCast` nodes whose name is 'epoch' (case-insensitive) and whose
    target type is one of `exp.DataType.TEMPORAL_TYPES` are affected.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Every SEMI / ANTI join that has an ON condition is removed from the query and a
    `[NOT] EXISTS (SELECT 1 FROM <joined source> WHERE <on condition>)` predicate is added
    to the query through `Select.where`. Joins without an ON condition are left untouched.

    Args:
        expression: The expression to transform. Only `exp.Select` nodes are rewritten,
            and they are modified in place.

    Returns:
        The expression that was passed in.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: `join.pop()` below removes the join from the
        # Select's join list, and mutating that list while iterating over it directly would
        # skip the join that follows each removed one (e.g. two consecutive SEMI joins)
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites a SELECT containing exactly one FULL OUTER join as a UNION ALL of two queries:
    the original query with a LEFT join, and a copy with a RIGHT join that is restricted,
    through a NOT EXISTS predicate, to rows the LEFT-join branch does not produce.

    Queries with zero or several FULL OUTER joins, and non-SELECT expressions, are
    returned unchanged.
    """
    if not isinstance(expression, exp.Select):
        return expression

    candidates = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]
    if len(candidates) != 1:
        return expression

    # The copy is taken before any mutation, so the right-hand branch keeps the LIMIT
    right_branch = expression.copy()
    expression.set("limit", None)
    position, full_join = candidates[0]

    left_name = expression.args["from"].alias_or_name
    right_name = full_join.alias_or_name

    condition = full_join.args.get("on")
    if not condition:
        # No ON clause: derive the equality conditions from the USING column list
        condition = exp.and_(
            *[
                exp.column(col, left_name).eq(exp.column(col, right_name))
                for col in full_join.args.get("using")
            ]
        )

    full_join.set("side", "left")
    anti_join = exp.select("1").from_(expression.args["from"]).where(condition)

    right_branch.args["joins"][position].set("side", "right")
    right_branch = right_branch.where(exp.Exists(this=anti_join).not_())
    right_branch.args.pop("with", None)  # remove CTEs from RIGHT side
    expression.args.pop("order", None)  # remove order by from LEFT side

    return exp.union(expression, right_branch, copy=False, distinct=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: ~E) -> ~E:
def move_ctes_to_top_level(expression: E) -> E:
    """
    Hoists every nested WITH clause into the WITH clause of the given (top-level) expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only accept CTEs at the
    top level, which makes queries such as

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    invalid there. Applying this transformation yields SQL that is syntactically valid in
    those dialects.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    root_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        if nested_with.parent is expression:
            continue

        if not root_with:
            # No top-level WITH yet: the first nested one is promoted as-is
            root_with = nested_with.pop()
            expression.set("with", root_with)
            continue

        if nested_with.recursive:
            root_with.set("recursive", True)

        # Must be looked up before popping, while the nested WITH is still attached
        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if enclosing_cte:
            # CTEs defined inside another CTE are placed right before that CTE
            ctes = root_with.expressions
            position = ctes.index(enclosing_cte)
            ctes[position:position] = nested_with.expressions
            root_with.set("expressions", ctes)
        else:
            root_with.set("expressions", root_with.expressions + nested_with.expressions)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites numeric values that appear in conditions as explicit `<value> <> 0` comparisons.

    Every node of the tree is handed to the optimizer's `canonicalize.ensure_bools` rule,
    together with a callback that performs the actual replacement.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _compare_with_zero(node: exp.Expression) -> None:
        if node.is_number:
            should_wrap = True
        elif not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            should_wrap = True
        else:
            # Columns that carry no type information are wrapped as well
            should_wrap = isinstance(node, exp.Column) and not node.type

        if should_wrap:
            node.replace(node.neq(0))

    for current in expression.walk():
        _canonicalize_ensure_bools(current, _compare_with_zero)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strips the table, db and catalog qualifiers from every column in the expression."""
    for col in expression.find_all(exp.Column):
        # All parts except the last one are qualifiers (table, db, catalog)
        qualifiers = col.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Drops UNIQUE column constraints from a CREATE statement by removing the parent node of
    every `exp.UniqueColumnConstraint` found in it.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique.parent
        if owner:
            owner.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Maps CREATE TEMPORARY TABLE statements onto temporary views.

    Args:
        expression: The `exp.Create` statement to transform.
        tmp_storage_provider: Callback applied to temporary-table CREATE statements that
            have no source query (i.e. are not CTAS); defaults to the identity function.

    Returns:
        A new `CREATE TEMPORARY VIEW` statement for a temporary CTAS, the result of
        `tmp_storage_provider` for other temporary tables, and the input otherwise.
    """
    assert isinstance(expression, exp.Create)

    props = expression.args.get("properties")
    prop_list = props.expressions if props else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in prop_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Moves partition columns out of a CREATE statement's schema and into PARTITIONED BY.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When its
    value is a list of column names rather than a schema, the matching column definitions
    (compared by upper-cased name) are removed from the table schema and become the schema
    of the PARTITIONED BY property.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    partition_names = {value.name.upper() for value in prop.this.expressions}
    partition_defs = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_defs = [column for column in schema.expressions if column not in partition_defs]

    schema.set("expressions", remaining_defs)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_defs)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Appends typed PARTITIONED BY column definitions to the CREATE statement's schema and
    leaves only the column names in the PARTITIONED BY property.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(col, exp.ColumnDef) and col.kind for col in partition_defs):
        return expression

    # The name tuple is built before the definitions are attached to the table schema
    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(col.this) for col in partition_defs]
    )

    schema = expression.this
    for col in partition_defs:
        schema.append("expressions", col)

    prop.set("this", partition_names)
    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Rewrites key-value struct arguments (`exp.PropertyEQ`) as aliases, e.g. STRUCT(1 AS y)."""
    if not isinstance(expression, exp.Struct):
        return expression

    fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            # The key becomes the alias of the value
            field = exp.alias_(field.expression, field.this)
        fields.append(field)

    expression.set("expressions", fields)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Only queries that have both a WHERE clause and joins are processed
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            # Nearest enclosing predicate (the search also stops at a Select ancestor)
            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Detach the predicate from its current parent; it is reused as the ON condition
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the tables of the marked columns
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column from the loop above; per the assertion, all marked
            # columns share a single table. The FROM clause is the fallback when that table is
            # not one of the existing joins.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # If the FROM source itself was turned into a join, a source that was previously
        # joined and has no new join takes its place in the FROM clause
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # The query's join list is replaced with the newly built joins
        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause if it no longer has a condition
        if not where.this:
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example, `SELECT * FROM a, b WHERE a.id = b.id(+)` is converted to `SELECT * FROM a LEFT JOIN b ON a.id = b.id`.

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.

def any_to_exists( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites comparisons against ANY(<array>) as Spark's EXISTS(<array>, <lambda>).

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    ANY and EXISTS both accept queries as well, but this transformation currently only
    handles array expressions; ANY over a query is left as is.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for any_node in expression.find_all(exp.Any):
        operand = any_node.this
        if isinstance(operand, exp.Query):
            continue

        comparison = any_node.parent
        if not isinstance(comparison, exp.Binary):
            continue

        # The ANY call is swapped for the lambda parameter, so that a copy of the
        # comparison can serve as the lambda body
        param = exp.to_identifier("x")
        any_node.replace(param)
        predicate = exp.Lambda(this=comparison.copy(), expressions=[param])
        comparison.replace(exp.Exists(this=operand.unnest(), expression=predicate))

    return expression

Transform ANY operator to Spark's EXISTS

For example:
  • Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
  • Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation