Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def preprocess(
 13    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 14) -> t.Callable[[Generator, exp.Expression], str]:
 15    """
 16    Creates a new transform by chaining a sequence of transformations and converts the resulting
 17    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
 18    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
 19
 20    Args:
 21        transforms: sequence of transform functions. These will be called in order.
 22
 23    Returns:
 24        Function that can be used as a generator transform.
 25    """
 26
 27    def _to_sql(self, expression: exp.Expression) -> str:
 28        expression_type = type(expression)
 29
 30        expression = transforms[0](expression)
 31        for transform in transforms[1:]:
 32            expression = transform(expression)
 33
 34        _sql_handler = getattr(self, expression.key + "_sql", None)
 35        if _sql_handler:
 36            return _sql_handler(expression)
 37
 38        transforms_handler = self.TRANSFORMS.get(type(expression))
 39        if transforms_handler:
 40            if expression_type is type(expression):
 41                if isinstance(expression, exp.Func):
 42                    return self.function_fallback_sql(expression)
 43
 44                # Ensures we don't enter an infinite loop. This can happen when the original expression
 45                # has the same type as the final expression and there's no _sql method available for it,
 46                # because then it'd re-enter _to_sql.
 47                raise ValueError(
 48                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
 49                )
 50
 51            return transforms_handler(self, expression)
 52
 53        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 54
 55    return _to_sql
 56
 57
 58def unalias_group(expression: exp.Expression) -> exp.Expression:
 59    """
 60    Replace references to select aliases in GROUP BY clauses.
 61
 62    Example:
 63        >>> import sqlglot
 64        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 65        'SELECT a AS b FROM x GROUP BY 1'
 66
 67    Args:
 68        expression: the expression that will be transformed.
 69
 70    Returns:
 71        The transformed expression.
 72    """
 73    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 74        aliased_selects = {
 75            e.alias: i
 76            for i, e in enumerate(expression.parent.expressions, start=1)
 77            if isinstance(e, exp.Alias)
 78        }
 79
 80        for group_by in expression.expressions:
 81            if (
 82                isinstance(group_by, exp.Column)
 83                and not group_by.table
 84                and group_by.name in aliased_selects
 85            ):
 86                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 87
 88    return expression
 89
 90
 91def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 92    """
 93    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 94
 95    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 96
 97    Args:
 98        expression: the expression that will be transformed.
 99
100    Returns:
101        The transformed expression.
102    """
103    if (
104        isinstance(expression, exp.Select)
105        and expression.args.get("distinct")
106        and expression.args["distinct"].args.get("on")
107        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
108    ):
109        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
110        outer_selects = expression.selects
111        row_number = find_new_name(expression.named_selects, "_row_number")
112        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
113        order = expression.args.get("order")
114
115        if order:
116            window.set("order", order.pop())
117        else:
118            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
119
120        window = exp.alias_(window, row_number)
121        expression.select(window, copy=False)
122
123        return (
124            exp.select(*outer_selects, copy=False)
125            .from_(expression.subquery("_t", copy=False), copy=False)
126            .where(exp.column(row_number).eq(1), copy=False)
127        )
128
129    return expression
130
131
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        For a SELECT with a QUALIFY clause, `SELECT <projection names> FROM (<query>) AS _t
        WHERE <qualify filter>`; otherwise `expression` unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c..." alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Build the outer reference, keeping the quoting of the projection's identifier.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer query re-selects the original projections by name.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With SELECT * every column is already exposed by the subquery, so only window
        # functions need to be hoisted into projections.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                # Inline aliased projections referenced inside the window, since those aliases
                # can't be referenced from within the same SELECT.
                # NOTE(review): `expr` is the projection's own node, not a copy, so it ends up
                # referenced from two places — consider `expr.copy()`.
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh "_w..." alias and filter on that alias instead.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # If the window is the whole QUALIFY condition, the filter itself becomes the
                # column reference; otherwise the window is swapped out inside the filter.
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Columns used by QUALIFY but not projected must be exposed by the subquery.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
194
195
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the precision parameters from every data type found in `expression`.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. Only `exp.DataTypeParam` arguments of each `exp.DataType` are
    dropped; its other arguments are kept in order.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept_args)

    return expression
207
208
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    For a SELECT, collects the aliases of the UNNESTs in its own scope whose parent is a FROM or
    JOIN clause. It then clears the table qualifier of every column (anywhere under the SELECT)
    that uses one of those aliases, or the db qualifier if that is what matches instead.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression
227
228
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    - An UNNEST used directly as the FROM source is replaced by a table wrapping EXPLODE
      (POSEXPLODE when the UNNEST has an offset). The UNNEST's first expression becomes the
      UDTF's argument, the remaining ones its extra arguments, and the alias (if any) is kept.
    - Every UNNEST used as a join is removed from the joins. For each unnested expression that
      is paired with a column alias, a `LATERAL VIEW` of EXPLODE / POSEXPLODE is appended.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (modified in place).
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: `joins` is mutated below, and removing items from a list while
        # iterating over it skips the element that follows each removed one.
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                # NOTE(review): expressions without a matching column alias (all of them when the
                # UNNEST has no alias) get no LATERAL VIEW, so they are dropped from the query.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
269
270
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    The returned transform rewrites a SELECT whose projections contain [POS]EXPLODE:

    - A position series `UNNEST(GENERATE_SERIES(index_offset, <end>))` with a fresh alias is
      added once. It is CROSS JOINed, or becomes the FROM when the query has none. `<end>` is
      derived from the GREATEST of the exploded arrays' sizes.
    - Each exploded array is CROSS JOINed as an `UNNEST(<array>)` with an offset column. Its
      projection becomes `IF(<series pos> = <array pos>, <element>)`, and POSEXPLODE also gets a
      matching position projection.
    - WHERE filters keep rows where the series position equals the array position. Once the
      series is past the end of an array, they keep the row paired with that array's last
      position, for which the IF projection yields NULL.

    Args:
        index_offset: start of the generated position series. It also decides whether array
            sizes are used as-is (`index_offset == 1`) or reduced by one to get the last
            position. Presumably this is the target dialect's array index base — TODO confirm
            with callers.

    Returns:
        A transform that takes an expression and returns the (possibly rewritten) expression.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / by sources; every new alias must avoid them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a fresh name derived from `name` and reserve it in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded array; their GREATEST bounds the position series.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # `series` is the UNNEST(GENERATE_SERIES(...)) node carrying a table alias; its end
            # bound is filled in after the loop, once all arrays are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Work out the output alias(es): `... AS a` names the element, `... AS (p, c)`
                    # names position and element, otherwise an alias node is added and named below.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # EXPLODE_OUTER must keep a row for NULL/empty arrays. Such arrays are replaced
                    # by a one-element array holding a safe, offset-based access to element 0
                    # (presumably NULL for them — TODO confirm in the target dialect).
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # Plain EXPLODE still needs a position column for the UNNEST offset.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The element is only produced where the series position matches its position.
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also projects the position, right after the element projection.
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # Add the position series once, before the first exploded array.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` now becomes this array's last position. Rebinding the local does not
                    # change the ARRAY_SIZE entry already stored in `arrays`.
                    if index_offset != 1:
                        size = size - 1

                    # Keep matching positions or, past this array's end, only its last position.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            # The series runs up to the largest array's last position.
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
420
421
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    A percentile function with both a `this` and an `expression` argument, and not already
    wrapped in WITHIN GROUP, becomes `F(<expression>) WITHIN GROUP (ORDER BY <this>)`.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression
    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)
435
436
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    `F(quantile) WITHIN GROUP (ORDER BY value)` is replaced by
    `exp.ApproxQuantile(this=value, quantile=quantile)`. `value` is taken from the first
    `exp.Ordered` node found under the WITHIN GROUP node.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=ordered.this, quantile=expression.this.this)
    return expression.replace(approx)
449
450
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs whose alias has no explicit column list are changed. For a set operation the
    projections of its `this` operand are used. Unnamed projections get names produced by a
    single `name_sequence("_c_")` shared by the whole WITH clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects]
        cte_alias.set("columns", column_names)

    return expression
468
469
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent timestamp literal.

    `CAST('epoch' AS <temporal type>)` (or TRY_CAST) becomes
    `CAST('1970-01-01 00:00:00' AS <temporal type>)`. Only string literals are rewritten;
    a column that happens to be named `epoch` is left untouched.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        # Without this check `CAST(epoch AS TIMESTAMP)` (a column reference) would be
        # replaced by a literal as well.
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
480
481
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    `... SEMI JOIN t ON cond` becomes a `WHERE EXISTS (SELECT 1 FROM t WHERE cond)` filter, and
    `... ANTI JOIN t ON cond` becomes `WHERE NOT EXISTS (...)`. The filters are ANDed with any
    existing WHERE clause. SEMI/ANTI joins without an ON condition are left unchanged.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` removes the join from the list being iterated,
        # which would otherwise skip the join right after each removed one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
497
498
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The LEFT JOIN copy yields every left-side row, matched or not. The RIGHT JOIN copy is
    restricted with `NOT EXISTS (SELECT 1 FROM <FROM source> WHERE <join condition>)` so that it
    only yields right-side rows without a match. The two copies are combined with UNION ALL,
    which keeps duplicate rows exactly like a FULL OUTER JOIN does. When no join condition can be
    derived (no ON and no usable USING), they are combined with a plain (distinct) UNION instead.

    LIMIT and ORDER BY are removed from the LEFT copy and kept on the RIGHT copy. That copy is
    generated last, so they apply to the whole union. CTEs are kept on the LEFT copy only.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The union, or `expression` unchanged unless it has exactly one FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            # An ORDER BY on the left operand of a set operation is not valid SQL
            expression.set("order", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            from_ = expression.args.get("from")
            join_condition = full_outer_join.args.get("on")
            using = full_outer_join.args.get("using")

            if join_condition is None and using and from_:
                left_name = from_.this.alias_or_name
                right_name = full_outer_join.this.alias_or_name

                if left_name and right_name:

                    def _qualified(col: exp.Expression, table: str) -> exp.Column:
                        # Copy the identifier so the USING list keeps ownership of its nodes
                        identifier = col.this if isinstance(col, exp.Column) else col
                        return exp.column(identifier.copy(), table)

                    join_condition = exp.and_(
                        *(
                            _qualified(col, left_name).eq(_qualified(col, right_name))
                            for col in using
                        )
                    )

            if join_condition is None or not from_:
                return exp.union(expression, expression_copy, copy=False)

            # Right-side rows that have a match are already produced by the LEFT JOIN copy
            matched = exp.select("1").from_(from_.this).where(join_condition)
            expression_copy.where(exp.Exists(this=matched).not_(), copy=False)

            return exp.union(expression, expression_copy, distinct=False, copy=False)

    return expression
523
524
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    Args:
        expression: the root expression; nested WITH clauses are hoisted into its "with" arg.

    Returns:
        `expression`, modified in place.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # The root's own WITH clause is the hoisting target, not something to hoist.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: the first nested one found is detached and becomes it.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # RECURSIVE is a flag on the whole WITH clause, so it must carry over.
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the nested CTEs right before the CTE that contained them, so they are
                # defined before the CTE that uses them.
                # NOTE(review): `.index` raises ValueError if `parent_cte` is not (yet) a
                # top-level CTE — confirm this can't happen for deeper nesting.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                # Re-set the mutated list, presumably so the inserted CTEs get re-parented to
                # the top-level WITH.
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
562
563
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of `expression` is handed to the optimizer's canonicalize `ensure_bools` helper
    together with a callback. The callback replaces a node `x` by `x <> 0` when `x` is one of:
    a number, a non-subquery-predicate expression typed UNKNOWN or numeric, or an untyped column.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as canonicalize_ensure_bools

    def _to_explicit_bool(node: exp.Expression) -> None:
        if node.is_number:
            needs_comparison = True
        elif not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            needs_comparison = True
        else:
            needs_comparison = isinstance(node, exp.Column) and not node.type

        if needs_comparison:
            node.replace(node.neq(0))

    for node in expression.walk():
        canonicalize_ensure_bools(node, _to_explicit_bool)

    return expression
583
584
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strip the table, db and catalog qualifiers from every column found in `expression`."""
    for column in expression.find_all(exp.Column):
        # All parts but the last are qualifiers; the last one is the column's own name.
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
592
593
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Drop UNIQUE column constraints from a CREATE statement.

    For every `exp.UniqueColumnConstraint` in the tree, its parent node is popped.
    `expression` must be an `exp.Create` (checked with an assertion).
    """
    assert isinstance(expression, exp.Create)
    for unique in expression.find_all(exp.UniqueColumnConstraint):
        holder = unique.parent
        if holder:
            holder.pop()

    return expression
601
602
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite CREATE TEMPORARY TABLE statements for dialects that model them as temporary views.

    - A temporary table created from a query becomes `CREATE TEMPORARY VIEW <name> AS <query>`.
      No other part of the original statement (e.g. its other properties) is carried over.
    - A temporary table without a query is passed to `tmp_storage_provider`, and its result
      is returned.
    - Anything else is returned unchanged.

    `expression` must be an `exp.Create` (checked with an assertion).
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    query = expression.expression
    if not query:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(kind="TEMPORARY VIEW", this=expression.this, expression=query)
625
626
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Only applies to CREATE TABLE / CREATE VIEW statements with a schema. Column names are
    matched case-insensitively. `expression` must be an `exp.Create` (checked with an assertion).
    """
    assert isinstance(expression, exp.Create)
    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    partition_names = {name_node.name.upper() for name_node in prop.this.expressions}
    partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
    remaining = [col for col in schema.expressions if col not in partitions]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return expression
648
649
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    This applies when PARTITIONED BY holds full column definitions, i.e. a schema whose entries
    are all `exp.ColumnDef`s with a type. Those definitions are appended to `expression.this`,
    and PARTITIONED BY is reduced to a tuple of the column names. `expression` must be an
    `exp.Create` (checked with an assertion).
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        # `to_identifier` returns an Identifier argument unchanged, so copy the name first;
        # otherwise the tuple and the column definition appended below would share one node.
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this.copy()) for e in prop.this.expressions]
        )
        # NOTE(review): assumes `expression.this` is the table's exp.Schema; confirm how callers
        # handle a CREATE whose target has no column list.
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression
673
674
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each `exp.PropertyEQ` argument of an `exp.Struct` is replaced by an alias of its
    `expression` (the value) named after its `this` (the key). Other arguments are kept as-is.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    converted = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            arg = exp.alias_(arg.expression, arg.this)
        converted.append(arg)

    expression.set("expressions", converted)
    return expression
687
688
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Each binary predicate that contains (+)-marked columns is moved out of the WHERE clause and
    into the ON condition of a LEFT join on the marked table. Several predicates on the same
    table are ANDed together. Joins that aren't converted are preserved in their original order.
    Among those, joins without kind, side or condition (e.g. comma joins) become CROSS JOINs so
    they can be mixed with the new explicit joins. If the FROM table itself is marked, the first
    unconverted join is promoted to FROM. Scopes without marked columns are left untouched.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: if the (+) marker is used in an unsupported way.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # All marked columns share one table (asserted above), so the last `col` names it.
            # A table that isn't one of the joins is taken to be the FROM source.
            # NOTE(review): LEFT is stored as the join's `kind` rather than its `side`.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if not new_joins:
            # No (+) markers in this scope: nothing was changed, so keep its joins as they are.
            continue

        from_name = query_from.alias_or_name
        promoted_name = None
        if from_name in new_joins:
            # The FROM source is itself outer-joined, so an unconverted join is promoted to FROM.
            # It is picked in join order (not from a set) to keep the output deterministic.
            promoted_name = next((name for name in old_joins if name not in new_joins), None)
            assert (
                promoted_name is not None
            ), "Cannot determine which table to use in the new FROM clause"
            query.set("from", exp.From(this=old_joins[promoted_name].this))

        # Rebuild the join list in the original order. Converted joins are replaced by their
        # LEFT JOIN, the promoted join's slot goes to the former FROM source's LEFT JOIN, and
        # every other join is kept.
        final_joins: t.List[exp.Join] = []
        for name, old_join in old_joins.items():
            if name == promoted_name:
                join = new_joins[from_name]
            elif name in new_joins:
                join = new_joins[name]
            else:
                if not (
                    old_join.kind
                    or old_join.side
                    or old_join.args.get("on")
                    or old_join.args.get("using")
                ):
                    # A bare (comma) join can't be mixed with explicit joins whose conditions
                    # reference earlier tables, so spell it out as a CROSS JOIN.
                    old_join.set("kind", "CROSS")
                join = old_join

            if all(join is not existing for existing in final_joins):
                final_joins.append(join)

        query.set("joins", final_joins)

        if not where.this:
            where.pop()

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated as-is.

    Returns:
        Function that can be used as a generator transform.

    The returned function raises ValueError when the transformed expression has no "_sql"
    method and either has no `TRANSFORMS` entry, or has one but is of the same type as the
    input expression (and is not a function, which falls back to `function_fallback_sql`).
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Apply the transforms in order. Iterating over the whole sequence (instead of
        # special-casing transforms[0]) also makes an empty sequence a no-op, not an IndexError.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY keys that refer to SELECT aliases as 1-based positional references.

    Only unqualified column keys whose name matches an alias of the enclosing SELECT are
    rewritten; everything else is left untouched.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Position (starting at 1) of every aliased projection; a repeated alias keeps the last one.
    position_by_alias: t.Dict[str, int] = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for key in expression.expressions:
        if isinstance(key, exp.Column) and not key.table:
            position = position_by_alias.get(key.name)
            if position is not None:
                key.replace(exp.Literal.number(position))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 92def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 93    """
 94    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 95
 96    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 97
 98    Args:
 99        expression: the expression that will be transformed.
100
101    Returns:
102        The transformed expression.
103    """
104    if (
105        isinstance(expression, exp.Select)
106        and expression.args.get("distinct")
107        and expression.args["distinct"].args.get("on")
108        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
109    ):
110        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
111        outer_selects = expression.selects
112        row_number = find_new_name(expression.named_selects, "_row_number")
113        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
114        order = expression.args.get("order")
115
116        if order:
117            window.set("order", order.pop())
118        else:
119            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
120
121        window = exp.alias_(window, row_number)
122        expression.select(window, copy=False)
123
124        return (
125            exp.select(*outer_selects, copy=False)
126            .from_(expression.subquery("_t", copy=False), copy=False)
127            .where(exp.column(row_number).eq(1), copy=False)
128        )
129
130    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        For a SELECT with a QUALIFY clause, a new `SELECT ... FROM (<original query>) AS _t
        WHERE <qualify condition>` query; otherwise `expression` unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh `_c*` alias so the outer query can reference it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Outer projections reference the subquery's output names, keeping identifier quoting.
        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection the subquery already exposes every column, so only window
        # functions have to be projected; otherwise referenced columns are candidates too.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        # NOTE(review): the tree is mutated while this find_all traversal is still running, so the
        # statement order inside the loop matters; do not restructure without tests.
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                # Replace references to projection aliases inside the window with the aliased
                # expressions, to avoid creating invalid column references.
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the subquery under a fresh `_w*` alias...
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # ...and refer to it by that alias in the outer filter. When the window is the whole
                # QUALIFY condition, the alias column becomes the entire filter.
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Columns used by QUALIFY must be projected to be visible in the outer WHERE.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every DataType node found in `expression` keeps only the entries of its `expressions` that are
    not DataTypeParam nodes. The tree is modified in place and returned.
    """

    def _is_not_type_param(node: exp.Expression) -> bool:
        return not isinstance(node, exp.DataTypeParam)

    for data_type in expression.find_all(exp.DataType):
        data_type.set("expressions", list(filter(_is_not_type_param, data_type.expressions)))

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Only UNNESTs in this SELECT's scope that sit directly in a FROM or JOIN are considered. A
    column whose table qualifier is one of their aliases loses that qualifier; otherwise, a column
    whose db qualifier is one of those aliases loses its db qualifier.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if not aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in aliases:
            column.set("table", None)
        elif column.db in aliases:
            column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    - An UNNEST in the FROM clause is replaced by a table over EXPLODE (POSEXPLODE when the UNNEST
      has an offset) of the UNNEST's arguments, keeping the UNNEST's alias.
    - Every UNNEST join is removed; for each unnested expression paired with one of the alias's
      columns, a LATERAL VIEW EXPLODE/POSEXPLODE is appended instead.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Iterate over a snapshot of the joins: UNNEST joins are removed from the list inside the
        # loop, and removing items from a list while iterating over it skips the item following
        # each removal (so consecutive UNNEST joins would not all be converted).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): zip() stops at the shorter sequence, so unnested expressions without
                # a matching alias column (all of them, if there is no alias) are dropped together
                # with the join — confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the first value of the generated position series. It also determines how
            the last position of each array and the end of the series are computed (see the
            `index_offset != 1` adjustments below). NOTE(review): presumably the target dialect's
            array index base (0 or 1) — confirm.

    Returns:
        A transform that rewrites a SELECT whose projections contain EXPLODE-family calls
        (EXPLODE, POSEXPLODE, EXPLODE_OUTER) into CROSS JOINed UNNESTs.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources, so generated aliases don't collide.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            # Pick an unused name derived from `name` and reserve it in `names`.
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded array; used to compute the series end below.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # A single UNNEST(GENERATE_SERIES(index_offset, ...)) source shared by all exploded
            # arrays; its end is filled in after the loop, once all arrays are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # `x AS a`: `a` names the exploded value. `AS (p, a)` (Aliases): name the
                    # position and the value. Unaliased: wrap in an empty alias, named below.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # EXPLODE_OUTER: when the array is NULL or empty, unnest a one-element array
                    # built from a "safe" element access instead, so that a row is still produced.
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The EXPLODE call is replaced by IF(series_pos = unnest_pos, unnested_value),
                    # which has no ELSE branch and so is NULL on rows where the positions differ.
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also projects the position, right after the value projection.
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # Attach the position series once, for the first exploded array only.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` now denotes the last position of the array: size - 1 unless
                    # index_offset is 1. NOTE(review): assumes UNNEST's OFFSET uses that same base.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position equals this array's offset, or where the
                    # series has run past this array's last position and the offset is that last
                    # position (the IF projections above then yield NULL for this array).
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            # The series must cover the longest exploded array.
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    A percentile function that has an `expression` argument and is not already wrapped in WITHIN
    GROUP is rewritten: its `this` argument moves into an `ORDER BY` inside WITHIN GROUP, and its
    `expression` argument becomes its `this`.
    """
    if (
        not isinstance(expression, exp.PERCENTILES)
        or isinstance(expression.parent, exp.WithinGroup)
        or not expression.expression
    ):
        return expression

    ordered_column = expression.this.pop()
    fraction = expression.expression.pop()
    expression.set("this", fraction)

    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_column)]),
    )

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    `PERCENTILE(q) WITHIN GROUP (ORDER BY x)` is replaced in the tree by an ApproxQuantile whose
    `this` is the first ordered expression (`x`) and whose `quantile` is the percentile's
    argument (`q`).
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    percentile_arg = expression.this.this
    first_ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=first_ordered.this, quantile=percentile_arg)
    return expression.replace(approx)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs of a RECURSIVE WITH that declare no columns are affected. For a set operation the
    projections of its left operand are used. Unnamed projections get `_c_*` names from a single
    sequence shared by all CTEs of the WITH clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        cte_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Only a string literal 'epoch' (case-insensitive) cast to a temporal type is rewritten, e.g.
    CAST('epoch' AS TIMESTAMP) -> CAST('1970-01-01 00:00:00' AS TIMESTAMP). A column or other
    non-literal operand that merely happens to be named `epoch` is left untouched.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        # Without this check `CAST(epoch AS TIMESTAMP)` on a column named `epoch` would have the
        # column replaced by a constant.
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    `... SEMI JOIN y ON c` becomes `WHERE EXISTS(SELECT 1 FROM y WHERE c)` and
    `... ANTI JOIN y ON c` becomes `WHERE NOT EXISTS(SELECT 1 FROM y WHERE c)`; each condition is
    added to the query's WHERE clause. SEMI/ANTI joins without an ON condition are left as-is.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` removes the join from the query's join list, and
        # mutating a list while iterating over it would skip the join that follows a removed one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The left branch is `expression` itself, with its LIMIT removed and the FULL join turned into a
    LEFT join. The right branch is a copy of the original query, with the join turned into a
    RIGHT join and its WITH clause removed. Queries with zero or several FULL joins are returned
    unchanged.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]
    if len(full_joins) != 1:
        return expression

    [(position, full_join)] = full_joins

    # Copy before any mutation: the copy becomes the RIGHT branch and keeps the original LIMIT.
    right_branch = expression.copy()
    expression.set("limit", None)
    full_join.set("side", "left")
    right_branch.args["joins"][position].set("side", "right")
    right_branch.args.pop("with", None)  # CTEs are kept on the LEFT branch only

    return exp.union(expression, right_branch, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed (it is modified in place).

    Returns:
        The same expression, with the CTEs of nested WITH clauses merged into its top-level WITH.
    """
    top_level_with = expression.args.get("with")
    # NOTE(review): WITH nodes are popped while this find_all traversal is still running, so the
    # statement order inside the loop matters.
    for inner_with in expression.find_all(exp.With):
        # This is the top-level WITH itself; nothing to move.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote this nested WITH as a whole.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # RECURSIVE is a property of the whole WITH clause, so keep it when merging.
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the nested CTEs right before the CTE that contained them, so they are
                # defined before it. Assumes parent_cte belongs to the top-level WITH.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                # Re-set the (mutated) list, presumably so the inserted CTEs get re-parented.
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    The canonicalizer's `ensure_bools` is applied to every node of the tree and decides which
    nodes are handed to the callback below. A node handed over is replaced by `node <> 0` when
    one of these holds:
    - it is a number;
    - it is not a subquery predicate and is typed as UNKNOWN or as a numeric type;
    - it is a column with no type.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as canonicalize_ensure_bools

    def _to_explicit_bool(node: exp.Expression) -> None:
        if node.is_number:
            needs_comparison = True
        elif not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            needs_comparison = True
        else:
            needs_comparison = isinstance(node, exp.Column) and not node.type

        if needs_comparison:
            node.replace(node.neq(0))

    for node in expression.walk():
        canonicalize_ensure_bools(node, _to_explicit_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip the table, db and catalog qualifiers from every column reference in `expression`.

    Only the final part of each column (the column itself) is kept. The tree is modified in place
    and returned.
    """
    for column in expression.find_all(exp.Column):
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove all UNIQUE constraints from a CREATE statement.

    A UNIQUE constraint wrapped in a `ColumnConstraint` (declared on a column, `a INT UNIQUE`) or in
    a `Constraint` (named, `CONSTRAINT c UNIQUE (a)`) is removed by removing that wrapper.

    An unnamed table-level constraint (`UNIQUE (a)`) may instead be a direct child of the table's
    Schema. Popping its parent there would drop the whole column list, so the constraint itself is
    removed instead.

    Args:
        expression: a CREATE expression; it is modified in place.

    Returns:
        The same expression, without UNIQUE constraints.
    """
    assert isinstance(expression, exp.Create)
    for constraint in list(expression.find_all(exp.UniqueColumnConstraint)):
        parent = constraint.parent
        if isinstance(parent, (exp.ColumnConstraint, exp.Constraint)):
            parent.pop()
        elif parent:
            constraint.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite CREATE TEMPORARY TABLE statements.

    For a CREATE TABLE that carries a TemporaryProperty:
    - a CTAS (`... AS <query>`) becomes a new `CREATE TEMPORARY VIEW <this> AS <query>`; other
      arguments of the original statement, such as its properties, are not carried over;
    - any other statement is passed to `tmp_storage_provider`, and its result is returned.

    Every other statement is returned unchanged.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []
    is_temporary = any(isinstance(node, exp.TemporaryProperty) for node in property_nodes)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Only TABLE/VIEW creates whose target is a schema are affected. Column names are matched
    case-insensitively. The expression is modified in place and returned.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not (partitioned_by and partitioned_by.this) or isinstance(partitioned_by.this, exp.Schema):
        # Nothing to do: no PARTITIONED BY, or it is already expressed as a schema
        return expression

    table_schema = expression.this
    partition_names = {ref.name.upper() for ref in partitioned_by.this.expressions}
    moved = [col for col in table_schema.expressions if col.name.upper() in partition_names]
    remaining = [col for col in table_schema.expressions if col not in moved]

    table_schema.set("expressions", remaining)
    partitioned_by.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=moved)))
    expression.set("this", table_schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    If PARTITIONED BY holds a schema whose entries are all typed column definitions, those
    definitions are appended to the expressions of `expression.this`. PARTITIONED BY is then
    reduced to a tuple of the bare column identifiers. The expression is modified in place
    and returned.
    """
    assert isinstance(expression, exp.Create)

    partitioned_by = expression.find(exp.PartitionedByProperty)
    partition_schema = partitioned_by.this if partitioned_by else None
    if not partition_schema or not isinstance(partition_schema, exp.Schema):
        return expression

    column_defs = partition_schema.expressions
    if not all(isinstance(col, exp.ColumnDef) and col.kind for col in column_defs):
        return expression

    names = exp.Tuple(expressions=[exp.to_identifier(col.this) for col in column_defs])
    target = expression.this
    for col in column_defs:
        target.append("expressions", col)
    partitioned_by.set("this", names)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Convert key/value (`PropertyEQ`) struct arguments into aliases, e.g. STRUCT(1 AS y).

    Non-struct expressions, and struct arguments that are not `PropertyEQ`, are left untouched.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            # value AS key
            field = exp.alias_(field.expression, field.this)
        fields.append(field)

    expression.set("expressions", fields)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Joins that carry no (+) predicate are preserved. Comma-style joins among them become
    explicit CROSS JOINs and are placed before the generated LEFT JOINs, so that every ON
    condition only refers to sources already in scope. Scopes without any join marks are
    left untouched.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                # Clearing the mark also makes later iterations skip the other marked
                # columns of this (already popped) predicate
                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # The single marked table becomes the optional side of a LEFT JOIN
            marked_table = next(iter(marked_column_tables))
            join_this = old_joins.get(marked_table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if not new_joins:
            # No (+) marks in this scope: its joins must be kept exactly as they are
            continue

        # Position of every source in the original FROM clause (the FROM source comes first)
        source_order = {
            name: index for index, name in enumerate([query_from.alias_or_name, *old_joins])
        }

        new_from_name = None
        if query_from.alias_or_name in new_joins:
            # The FROM source is now the optional side of a LEFT JOIN, so promote the first
            # unmarked joined source (in FROM-clause order, which keeps the result deterministic)
            unmarked = [name for name in old_joins if name not in new_joins]
            assert unmarked, "Cannot determine which table to use in the new FROM clause"

            new_from_name = unmarked[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # Keep every join that was not turned into a LEFT JOIN (and was not promoted to FROM)
        kept_joins = []
        for name, join in old_joins.items():
            if name in new_joins or name == new_from_name:
                continue
            # A bare `FROM a, b` join has none of these args. Comma joins bind less tightly
            # than explicit JOINs in standard SQL, so make it an explicit CROSS JOIN to keep
            # the sources visible to the ON conditions of the LEFT JOINs that follow
            if not any(join.args.get(arg) for arg in ("side", "kind", "method", "on", "using")):
                join.set("kind", "CROSS")
            kept_joins.append(join)

        left_joins = [
            join
            for _, join in sorted(
                new_joins.items(),
                key=lambda item: source_order.get(item[0], len(source_order)),
            )
        ]
        query.set("joins", kept_joins + left_joins)

        if not where.this:
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example, `SELECT * FROM a, b WHERE a.id = b.id(+)` is converted to `SELECT * FROM a LEFT JOIN b ON a.id = b.id`.

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.