sqlglot.transforms

   1from __future__ import annotations
   2
   3import typing as t
   4
   5from sqlglot import expressions as exp
   6from sqlglot.errors import UnsupportedError
   7from sqlglot.helper import find_new_name, name_sequence, seq_get
   8
   9
  10if t.TYPE_CHECKING:
  11    from sqlglot._typing import E
  12    from sqlglot.generator import Generator
  13
  14
  15def preprocess(
  16    transforms: list[t.Callable[[exp.Expr], exp.Expr]],
  17    generator: t.Callable[[Generator, exp.Expr], str] | None = None,
  18) -> t.Callable[[Generator, exp.Expr], str]:
  19    """
  20    Creates a new transform by chaining a sequence of transformations and converts the resulting
  21    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
  22    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
  23
  24    Args:
  25        transforms: sequence of transform functions. These will be called in order.
  26
  27    Returns:
  28        Function that can be used as a generator transform.
  29    """
  30
  31    def _to_sql(self, expression: exp.Expr) -> str:
  32        expression_type = type(expression)
  33
  34        try:
  35            expression = transforms[0](expression)
  36            for transform in transforms[1:]:
  37                expression = transform(expression)
  38        except UnsupportedError as unsupported_error:
  39            self.unsupported(str(unsupported_error))
  40
  41        if generator:
  42            return generator(self, expression)
  43
  44        _sql_handler = getattr(self, expression.key + "_sql", None)
  45        if _sql_handler:
  46            return _sql_handler(expression)
  47
  48        transforms_handler = self.TRANSFORMS.get(type(expression))
  49        if transforms_handler:
  50            if expression_type is type(expression):
  51                if isinstance(expression, exp.Func):
  52                    return self.function_fallback_sql(expression)
  53
  54                # Ensures we don't enter an infinite loop. This can happen when the original expression
  55                # has the same type as the final expression and there's no _sql method available for it,
  56                # because then it'd re-enter _to_sql.
  57                raise ValueError(
  58                    f"Expr type {expression.__class__.__name__} requires a _sql method in order to be transformed."
  59                )
  60
  61            return transforms_handler(self, expression)
  62
  63        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
  64
  65    return _to_sql
  66
  67
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expr) -> exp.Expr:
    """Rewrites `UNNEST(GENERATE_DATE_ARRAY(...))` table sources as recursive CTEs.

    For each qualifying unnest in a SELECT's FROM/JOIN clauses, a recursive CTE is
    built that starts at the cast start date and repeatedly applies DATE_ADD with
    the interval step until the end date is exceeded; the unnest is then replaced
    by a subquery selecting from that CTE.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            # Only rewrite unnests that are table sources wrapping exactly one
            # GENERATE_DATE_ARRAY call.
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            # Both bounds are required, and the step must be an interval literal.
            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # Suffix subsequent CTE names so multiple rewrites in one query don't collide.
            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            # UNION ALL (distinct=False) is the anchor/recursive-member combinator.
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            # Prepend the generated CTEs to any existing WITH clause and mark it recursive.
            with_expression = expression.args.get("with_") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with_", with_expression)

    return expression
 123
 124
 125def unnest_generate_series(expression: exp.Expr) -> exp.Expr:
 126    """Unnests GENERATE_SERIES or SEQUENCE table references."""
 127    this = expression.this
 128    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
 129        unnest = exp.Unnest(expressions=[this])
 130        if expression.alias:
 131            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
 132
 133        return unnest
 134
 135    return expression
 136
 137
 138def eliminate_distinct_on(expression: exp.Expr) -> exp.Expr:
 139    """
 140    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 141
 142    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 143
 144    Args:
 145        expression: the expression that will be transformed.
 146
 147    Returns:
 148        The transformed expression.
 149    """
 150    if (
 151        isinstance(expression, exp.Select)
 152        and expression.args.get("distinct")
 153        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
 154    ):
 155        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")
 156
 157        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 158        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 159
 160        order = expression.args.get("order")
 161        if order:
 162            window.set("order", order.pop())
 163        else:
 164            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 165
 166        expression.select(exp.alias_(window, row_number_window_alias), copy=False)
 167
 168        # We add aliases to the projections so that we can safely reference them in the outer query
 169        new_selects = []
 170        taken_names = {row_number_window_alias}
 171        for select in expression.selects[:-1]:
 172            if select.is_star:
 173                new_selects = [exp.Star()]
 174                break
 175
 176            if not isinstance(select, exp.Alias):
 177                alias = find_new_name(taken_names, select.output_name or "_col")
 178                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
 179                select = select.replace(exp.alias_(select, alias, quoted=quoted))
 180
 181            taken_names.add(select.output_name)
 182            new_selects.append(select.args["alias"])
 183
 184        return (
 185            exp.select(*new_selects, copy=False)
 186            .from_(expression.subquery("_t", copy=False), copy=False)
 187            .where(exp.column(row_number_window_alias).eq(1), copy=False)
 188        )
 189
 190    return expression
 191
 192
def eliminate_qualify(expression: exp.Expr) -> exp.Expr:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so the outer query can reference it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expr) -> str | exp.Column:
            # Preserve identifier quoting when re-projecting in the outer SELECT.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection every column is already selected, so only window
        # functions need handling; otherwise bare columns may need to be added too.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline newly-created aliases referenced inside the window, since
                    # those aliases aren't visible within the subquery's projections.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window function inside the subquery; filter on its alias outside.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # QUALIFY's predicate was the window function itself.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A column referenced only in QUALIFY must still be selected in the subquery.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
 255
 256
 257def remove_precision_parameterized_types(expression: exp.Expr) -> exp.Expr:
 258    """
 259    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
 260    other expressions. This transforms removes the precision from parameterized types in expressions.
 261    """
 262    for node in expression.find_all(exp.DataType):
 263        node.set(
 264            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
 265        )
 266
 267    return expression
 268
 269
 270def unqualify_unnest(expression: exp.Expr) -> exp.Expr:
 271    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
 272    from sqlglot.optimizer.scope import find_all_in_scope
 273
 274    if isinstance(expression, exp.Select):
 275        unnest_aliases = {
 276            unnest.alias
 277            for unnest in find_all_in_scope(expression, exp.Unnest)
 278            if isinstance(unnest.parent, (exp.From, exp.Join))
 279        }
 280        if unnest_aliases:
 281            for column in expression.find_all(exp.Column):
 282                leftmost_part = column.parts[0]
 283                if leftmost_part.arg_key != "this" and leftmost_part.this in unnest_aliases:
 284                    leftmost_part.pop()
 285
 286    return expression
 287
 288
def unnest_to_explode(
    expression: exp.Expr,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expr:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: list[exp.Expr], has_multi_expr: bool
    ) -> list[exp.Expr]:
        # Multiple arrays can only be exploded together by zipping them first.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: list[exp.Expr] = [exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> type[exp.Func]:
        # Pick the table-generating function: POSEXPLODE when an offset/position is
        # requested, INLINE for zipped struct arrays, plain EXPLODE otherwise.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from_")

        # Case 1: the UNNEST is the FROM source itself -- replace it with a UDTF table.
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *_ = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            columns = alias.columns if alias else []
            offset = unnest.args.get("offset")
            if offset:
                # The position column comes first in the UDTF's output.
                columns.insert(
                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
                )

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(this=this),
                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
                )
            )

        # Case 2: the UNNEST appears as a join source (possibly wrapped in LATERAL) --
        # remove the join and attach an equivalent LATERAL VIEW to the select.
        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                offset = unnest.args.get("offset")
                if offset:
                    # Prepend the position alias so it lines up with POSEXPLODE's output.
                    alias_cols.insert(
                        0,
                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
                    )

                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
 389
 390
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expr], exp.Expr]:
    """Convert explode/posexplode projections into unnests.

    Args:
        index_offset: the dialect's array index base (0 or 1), used to size and
            offset the generated position series.
    """

    def _explode_projection_to_unnest(expression: exp.Expr) -> exp.Expr:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: set[str], name: str) -> str:
                # Reserve the fresh name immediately so later calls can't reuse it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One shared GENERATE_SERIES source supplies output row positions; its
            # end bound is filled in after all exploded arrays are known (bottom).
            arrays: list[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias: t.Any = ""
                    explode_alias: t.Any = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias: exp.Expr = select
                    elif isinstance(select, exp.Aliases):
                        # POSEXPLODE aliased as (pos, value).
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Unaliased projection: wrap it so an alias can be attached below.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # EXPLODE_OUTER must still yield one row for empty/NULL arrays.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Emit the exploded value only where the series position matches
                    # this unnest's own position column (NULL elsewhere).
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Project the position column right after the exploded value.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # Attach the shared series source once, on the first explode found.
                        if expression.args.get("from_"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions line up, or clamp to the last
                    # element for series positions beyond this (shorter) array.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must span the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest
 542
 543
 544def add_within_group_for_percentiles(expression: exp.Expr) -> exp.Expr:
 545    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
 546    if (
 547        isinstance(expression, exp.PERCENTILES)
 548        and not isinstance(expression.parent, exp.WithinGroup)
 549        and expression.expression
 550    ):
 551        column = expression.this.pop()
 552        expression.set("this", expression.expression.pop())
 553        order = exp.Order(expressions=[exp.Ordered(this=column)])
 554        expression = exp.WithinGroup(this=expression, expression=order)
 555
 556    return expression
 557
 558
 559def remove_within_group_for_percentiles(expression: exp.Expr) -> exp.Expr:
 560    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
 561    if (
 562        isinstance(expression, exp.WithinGroup)
 563        and isinstance(expression.this, exp.PERCENTILES)
 564        and isinstance(expression.expression, exp.Order)
 565    ):
 566        quantile = expression.this.this
 567        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
 568        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
 569
 570    return expression
 571
 572
 573def add_recursive_cte_column_names(expression: exp.Expr) -> exp.Expr:
 574    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
 575    if isinstance(expression, exp.With) and expression.recursive:
 576        next_name = name_sequence("_c_")
 577
 578        for cte in expression.expressions:
 579            if not cte.args["alias"].columns:
 580                query = cte.this
 581                if isinstance(query, exp.SetOperation):
 582                    query = query.this
 583
 584                cte.args["alias"].set(
 585                    "columns",
 586                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
 587                )
 588
 589    return expression
 590
 591
 592def epoch_cast_to_ts(expression: exp.Expr) -> exp.Expr:
 593    """Replace 'epoch' in casts by the equivalent date literal."""
 594    if (
 595        isinstance(expression, (exp.Cast, exp.TryCast))
 596        and expression.name.lower() == "epoch"
 597        and expression.to.this in exp.DataType.TEMPORAL_TYPES
 598    ):
 599        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
 600
 601    return expression
 602
 603
 604def eliminate_semi_and_anti_joins(expression: exp.Expr) -> exp.Expr:
 605    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
 606    if isinstance(expression, exp.Select):
 607        for join in list(expression.args.get("joins") or []):
 608            on = join.args.get("on")
 609            if on and join.kind in ("SEMI", "ANTI"):
 610                subquery = exp.select("1").from_(join.this).where(on)
 611                exists: exp.Exists | exp.Not = exp.Exists(this=subquery)
 612                if join.kind == "ANTI":
 613                    exists = exists.not_(copy=False)
 614
 615                join.pop()
 616                expression.where(exists, copy=False)
 617
 618    return expression
 619
 620
def eliminate_full_outer_join(expression: exp.Expr) -> exp.Expr:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # Copy first: the original becomes the LEFT-join side of the union,
            # the copy becomes the RIGHT-join side.
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # Reconstruct an ON condition from USING columns when no ON clause exists.
            tables = (expression.args["from_"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            # NOT EXISTS on the RIGHT side excludes rows already produced by the
            # LEFT side, so the UNION ALL doesn't duplicate matched rows.
            anti_join_clause = (
                exp.select("1").from_(expression.args["from_"]).where(join_conditions)
            )
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.set("with_", None)  # remove CTEs from RIGHT side
            expression.set("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
 659
 660
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with_")
    for inner_with in expression.find_all(exp.With):
        # The top-level WITH (a direct child of `expression`) stays where it is.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote this nested one wholesale.
            top_level_with = inner_with.pop()
            expression.set("with_", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the nested CTEs *before* the CTE that referenced them,
                # preserving definition-before-use order.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
 698
 699
 700def ensure_bools(expression: exp.Expr) -> exp.Expr:
 701    """Converts numeric values used in conditions into explicit boolean expressions."""
 702    from sqlglot.optimizer.canonicalize import ensure_bools
 703
 704    def _ensure_bool(node: exp.Expr) -> None:
 705        if (
 706            node.is_number
 707            or (
 708                not isinstance(node, exp.SubqueryPredicate)
 709                and node.is_type(exp.DType.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
 710            )
 711            or (isinstance(node, exp.Column) and not node.type)
 712        ):
 713            node.replace(node.neq(0))
 714
 715    for node in expression.walk():
 716        ensure_bools(node, _ensure_bool)
 717
 718    return expression
 719
 720
 721def unqualify_columns(expression: exp.Expr) -> exp.Expr:
 722    for column in expression.find_all(exp.Column):
 723        # We only wanna pop off the table, db, catalog args
 724        for part in column.parts[:-1]:
 725            part.pop()
 726
 727    return expression
 728
 729
 730def remove_unique_constraints(expression: exp.Expr) -> exp.Expr:
 731    assert isinstance(expression, exp.Create)
 732    for constraint in expression.find_all(exp.UniqueColumnConstraint):
 733        if constraint.parent:
 734            constraint.parent.pop()
 735
 736    return expression
 737
 738
 739def ctas_with_tmp_tables_to_create_tmp_view(
 740    expression: exp.Expr,
 741    tmp_storage_provider: t.Callable[[exp.Expr], exp.Expr] = lambda e: e,
 742) -> exp.Expr:
 743    assert isinstance(expression, exp.Create)
 744    properties = expression.args.get("properties")
 745    temporary = any(
 746        isinstance(prop, exp.TemporaryProperty)
 747        for prop in (properties.expressions if properties else [])
 748    )
 749
 750    # CTAS with temp tables map to CREATE TEMPORARY VIEW
 751    if expression.kind == "TABLE" and temporary:
 752        if expression.expression:
 753            return exp.Create(
 754                kind="TEMPORARY VIEW",
 755                this=expression.this,
 756                expression=expression.expression,
 757            )
 758        return tmp_storage_provider(expression)
 759
 760    return expression
 761
 762
 763def move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr:
 764    """
 765    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
 766    PARTITIONED BY value is an array of column names, they are transformed into a schema.
 767    The corresponding columns are removed from the create statement.
 768    """
 769    assert isinstance(expression, exp.Create)
 770    has_schema = isinstance(expression.this, exp.Schema)
 771    is_partitionable = expression.kind in {"TABLE", "VIEW"}
 772
 773    if has_schema and is_partitionable:
 774        prop = expression.find(exp.PartitionedByProperty)
 775        if prop and prop.this and not isinstance(prop.this, exp.Schema):
 776            schema = expression.this
 777            columns = {v.name.upper() for v in prop.this.expressions}
 778            partitions = [col for col in schema.expressions if col.name.upper() in columns]
 779            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
 780            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
 781            expression.set("this", schema)
 782
 783    return expression
 784
 785
 786def move_partitioned_by_to_schema_columns(expression: exp.Expr) -> exp.Expr:
 787    """
 788    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
 789
 790    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
 791    """
 792    assert isinstance(expression, exp.Create)
 793    prop = expression.find(exp.PartitionedByProperty)
 794    if (
 795        prop
 796        and prop.this
 797        and isinstance(prop.this, exp.Schema)
 798        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
 799    ):
 800        prop_this = exp.Tuple(
 801            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
 802        )
 803        schema = expression.this
 804        for e in prop.this.expressions:
 805            schema.append("expressions", e)
 806        prop.set("this", prop_this)
 807
 808    return expression
 809
 810
 811def struct_kv_to_alias(expression: exp.Expr) -> exp.Expr:
 812    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
 813    if isinstance(expression, exp.Struct):
 814        expression.set(
 815            "expressions",
 816            [
 817                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
 818                for e in expression.expressions
 819            ],
 820        )
 821
 822    return expression
 823
 824
def eliminate_join_marks(expression: exp.Expr) -> exp.Expr:
    """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

    -- example with WHERE
    SELECT d.department_name, sum(e.salary) as total_salary
    FROM departments d, employees e
    WHERE e.department_id(+) = d.department_id
    group by department_name

    -- example of left correlation in select
    SELECT d.department_name, (
        SELECT SUM(e.salary)
            FROM employees e
            WHERE e.department_id(+) = d.department_id) AS total_salary
    FROM departments d;

    -- example of left correlation in from
    SELECT d.department_name, t.total_salary
    FROM departments d, (
            SELECT SUM(e.salary) AS total_salary
            FROM employees e
            WHERE e.department_id(+) = d.department_id
        ) t
    """

    from sqlglot.optimizer.scope import traverse_scope
    from sqlglot.optimizer.normalize import normalize, normalized
    from collections import defaultdict

    # we go in reverse to check the main query for left correlation
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins", [])

        # nothing to do unless the WHERE clause has at least one (+)-marked column
        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
            continue

        # knockout: we do not support left correlation (see point 2)
        assert not scope.is_correlated_subquery, "Correlated queries are not supported"

        # make sure we have AND of ORs to have clear join terms
        where = normalize(where.this)
        assert normalized(where), "Cannot normalize JOIN predicates"

        joins_ons = defaultdict(list)  # dict of {name: list of join AND conditions}
        # group each top-level conjunct by the single table whose columns carry the (+) mark
        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]

            left_join_table = set(col.table for col in join_cols)
            if not left_join_table:
                continue

            assert not (len(left_join_table) > 1), (
                "Cannot combine JOIN predicates from different tables"
            )

            # strip the (+) marker now that the predicate will move into an ON clause
            for col in join_cols:
                col.set("join_mark", False)

            joins_ons[left_join_table.pop()].append(cond)

        old_joins = {join.alias_or_name: join for join in joins}
        new_joins = {}
        query_from = query.args["from_"]

        # build one LEFT JOIN per marked table, reusing its existing join/FROM source
        for table, predicates in joins_ons.items():
            join_what = old_joins.get(table, query_from).this.copy()
            new_joins[join_what.alias_or_name] = exp.Join(
                this=join_what, on=exp.and_(*predicates), kind="LEFT"
            )

            # detach each predicate from the WHERE tree, collapsing the node it leaves behind
            for p in predicates:
                while isinstance(p.parent, exp.Paren):
                    p.parent.replace(p)

                parent = p.parent
                p.pop()
                if isinstance(parent, exp.Binary):
                    left = parent.args.get("this")
                    parent.replace(parent.right if left is None else left)
                elif isinstance(parent, exp.Where):
                    parent.pop()

        # if the FROM table itself became a LEFT JOIN target, promote another table to FROM
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert len(only_old_joins) >= 1, (
                "Cannot determine which table to use in the new FROM clause"
            )

            new_from_name = list(only_old_joins)[0]
            query.set("from_", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            for n, j in old_joins.items():  # preserve any other joins
                if n not in new_joins and n != query.args["from_"].name:
                    if not j.kind:
                        j.set("kind", "CROSS")
                    new_joins[n] = j
            query.set("joins", list(new_joins.values()))

    return expression
 944
 945
 946def any_to_exists(expression: exp.Expr) -> exp.Expr:
 947    """
 948    Transform ANY operator to Spark's EXISTS
 949
 950    For example,
 951        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
 952        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
 953
 954    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
 955    transformation
 956    """
 957    if isinstance(expression, exp.Select):
 958        for any_expr in expression.find_all(exp.Any):
 959            this = any_expr.this
 960            if isinstance(this, exp.Query) or isinstance(any_expr.parent, (exp.Like, exp.ILike)):
 961                continue
 962
 963            binop = any_expr.parent
 964            if isinstance(binop, exp.Binary):
 965                lambda_arg = exp.to_identifier("x")
 966                any_expr.replace(lambda_arg)
 967                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
 968                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))
 969
 970    return expression
 971
 972
 973def eliminate_window_clause(expression: exp.Expr) -> exp.Expr:
 974    """Eliminates the `WINDOW` query clause by inling each named window."""
 975    if isinstance(expression, exp.Select) and expression.args.get("windows"):
 976        from sqlglot.optimizer.scope import find_all_in_scope
 977
 978        windows = expression.args["windows"]
 979        expression.set("windows", None)
 980
 981        window_expression: dict[str, exp.Expr] = {}
 982
 983        def _inline_inherited_window(window: exp.Expr) -> None:
 984            inherited_window = window_expression.get(window.alias.lower())
 985            if not inherited_window:
 986                return
 987
 988            window.set("alias", None)
 989            for key in ("partition_by", "order", "spec"):
 990                arg = inherited_window.args.get(key)
 991                if arg:
 992                    window.set(key, arg.copy())
 993
 994        for window in windows:
 995            _inline_inherited_window(window)
 996            window_expression[window.name.lower()] = window
 997
 998        for window in find_all_in_scope(expression, exp.Window):
 999            _inline_inherited_window(window)
1000
1001    return expression
1002
1003
def inherit_struct_field_names(expression: exp.Expr) -> exp.Expr:
    """
    Inherit field names from the first struct in an array.

    BigQuery supports implicitly inheriting names from the first STRUCT in an array:

    Example:
        ARRAY[
          STRUCT('Alice' AS name, 85 AS score),  -- defines names
          STRUCT('Bob', 92),                     -- inherits names
          STRUCT('Diana', 95)                    -- inherits names
        ]

    This transformation makes the field names explicit on all structs by adding
    PropertyEQ nodes, in order to facilitate transpilation to other dialects.

    Args:
        expression: The expression tree to transform

    Returns:
        The modified expression with field names inherited in all structs
    """
    if not (isinstance(expression, exp.Array) and expression.args.get("struct_name_inheritance")):
        return expression

    template = seq_get(expression.expressions, 0)
    if not isinstance(template, exp.Struct) or not all(
        isinstance(fld, exp.PropertyEQ) for fld in template.expressions
    ):
        return expression

    field_names = [fld.this for fld in template.expressions]

    # Apply the first struct's field names to every later struct of the same arity
    for candidate in expression.expressions[1:]:
        if not isinstance(candidate, exp.Struct) or len(candidate.expressions) != len(field_names):
            continue

        renamed = []
        for name, value in zip(field_names, candidate.expressions):
            if isinstance(value, exp.PropertyEQ):
                # Already explicitly named; keep as-is
                renamed.append(value)
            else:
                # field_name := value, preserving the inner expression's resolved type
                named_field = exp.PropertyEQ(this=name.copy(), expression=value)
                named_field.type = value.type
                renamed.append(named_field)

        candidate.set("expressions", renamed)

    return expression
def preprocess( transforms: list[typing.Callable[[sqlglot.expressions.core.Expr], sqlglot.expressions.core.Expr]], generator: Optional[Callable[[sqlglot.generator.Generator, sqlglot.expressions.core.Expr], str]] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.core.Expr], str]:
16def preprocess(
17    transforms: list[t.Callable[[exp.Expr], exp.Expr]],
18    generator: t.Callable[[Generator, exp.Expr], str] | None = None,
19) -> t.Callable[[Generator, exp.Expr], str]:
20    """
21    Creates a new transform by chaining a sequence of transformations and converts the resulting
22    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
23    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
24
25    Args:
26        transforms: sequence of transform functions. These will be called in order.
27
28    Returns:
29        Function that can be used as a generator transform.
30    """
31
32    def _to_sql(self, expression: exp.Expr) -> str:
33        expression_type = type(expression)
34
35        try:
36            expression = transforms[0](expression)
37            for transform in transforms[1:]:
38                expression = transform(expression)
39        except UnsupportedError as unsupported_error:
40            self.unsupported(str(unsupported_error))
41
42        if generator:
43            return generator(self, expression)
44
45        _sql_handler = getattr(self, expression.key + "_sql", None)
46        if _sql_handler:
47            return _sql_handler(expression)
48
49        transforms_handler = self.TRANSFORMS.get(type(expression))
50        if transforms_handler:
51            if expression_type is type(expression):
52                if isinstance(expression, exp.Func):
53                    return self.function_fallback_sql(expression)
54
55                # Ensures we don't enter an infinite loop. This can happen when the original expression
56                # has the same type as the final expression and there's no _sql method available for it,
57                # because then it'd re-enter _to_sql.
58                raise ValueError(
59                    f"Expr type {expression.__class__.__name__} requires a _sql method in order to be transformed."
60                )
61
62            return transforms_handler(self, expression)
63
64        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
65
66    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
 69def unnest_generate_date_array_using_recursive_cte(expression: exp.Expr) -> exp.Expr:
 70    if isinstance(expression, exp.Select):
 71        count = 0
 72        recursive_ctes = []
 73
 74        for unnest in expression.find_all(exp.Unnest):
 75            if (
 76                not isinstance(unnest.parent, (exp.From, exp.Join))
 77                or len(unnest.expressions) != 1
 78                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 79            ):
 80                continue
 81
 82            generate_date_array = unnest.expressions[0]
 83            start = generate_date_array.args.get("start")
 84            end = generate_date_array.args.get("end")
 85            step = generate_date_array.args.get("step")
 86
 87            if not start or not end or not isinstance(step, exp.Interval):
 88                continue
 89
 90            alias = unnest.args.get("alias")
 91            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 92
 93            start = exp.cast(start, "date")
 94            date_add = exp.func(
 95                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 96            )
 97            cast_date_add = exp.cast(date_add, "date")
 98
 99            cte_name = "_generated_dates" + (f"_{count}" if count else "")
100
101            base_query = exp.select(start.as_(column_name))
102            recursive_query = (
103                exp.select(cast_date_add)
104                .from_(cte_name)
105                .where(cast_date_add <= exp.cast(end, "date"))
106            )
107            cte_query = base_query.union(recursive_query, distinct=False)
108
109            generate_dates_query = exp.select(column_name).from_(cte_name)
110            unnest.replace(generate_dates_query.subquery(cte_name))
111
112            recursive_ctes.append(
113                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
114            )
115            count += 1
116
117        if recursive_ctes:
118            with_expression = expression.args.get("with_") or exp.With()
119            with_expression.set("recursive", True)
120            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
121            expression.set("with_", with_expression)
122
123    return expression
def unnest_generate_series( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
126def unnest_generate_series(expression: exp.Expr) -> exp.Expr:
127    """Unnests GENERATE_SERIES or SEQUENCE table references."""
128    this = expression.this
129    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
130        unnest = exp.Unnest(expressions=[this])
131        if expression.alias:
132            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
133
134        return unnest
135
136    return expression

Unnests GENERATE_SERIES or SEQUENCE table references.

def eliminate_distinct_on( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
139def eliminate_distinct_on(expression: exp.Expr) -> exp.Expr:
140    """
141    Convert SELECT DISTINCT ON statements to a subquery with a window function.
142
143    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
144
145    Args:
146        expression: the expression that will be transformed.
147
148    Returns:
149        The transformed expression.
150    """
151    if (
152        isinstance(expression, exp.Select)
153        and expression.args.get("distinct")
154        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
155    ):
156        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")
157
158        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
159        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
160
161        order = expression.args.get("order")
162        if order:
163            window.set("order", order.pop())
164        else:
165            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
166
167        expression.select(exp.alias_(window, row_number_window_alias), copy=False)
168
169        # We add aliases to the projections so that we can safely reference them in the outer query
170        new_selects = []
171        taken_names = {row_number_window_alias}
172        for select in expression.selects[:-1]:
173            if select.is_star:
174                new_selects = [exp.Star()]
175                break
176
177            if not isinstance(select, exp.Alias):
178                alias = find_new_name(taken_names, select.output_name or "_col")
179                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
180                select = select.replace(exp.alias_(select, alias, quoted=quoted))
181
182            taken_names.add(select.output_name)
183            new_selects.append(select.args["alias"])
184
185        return (
186            exp.select(*new_selects, copy=False)
187            .from_(expression.subquery("_t", copy=False), copy=False)
188            .where(exp.column(row_number_window_alias).eq(1), copy=False)
189        )
190
191    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
194def eliminate_qualify(expression: exp.Expr) -> exp.Expr:
195    """
196    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
197
198    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
199    https://docs.snowflake.com/en/sql-reference/constructs/qualify
200
201    Some dialects don't support window functions in the WHERE clause, so we need to include them as
202    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
203    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
204    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
205    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
206    corresponding expression to avoid creating invalid column references.
207    """
208    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
209        taken = set(expression.named_selects)
210        for select in expression.selects:
211            if not select.alias_or_name:
212                alias = find_new_name(taken, "_c")
213                select.replace(exp.alias_(select, alias))
214                taken.add(alias)
215
216        def _select_alias_or_name(select: exp.Expr) -> str | exp.Column:
217            alias_or_name = select.alias_or_name
218            identifier = select.args.get("alias") or select.this
219            if isinstance(identifier, exp.Identifier):
220                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
221            return alias_or_name
222
223        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
224        qualify_filters = expression.args["qualify"].pop().this
225        expression_by_alias = {
226            select.alias: select.this
227            for select in expression.selects
228            if isinstance(select, exp.Alias)
229        }
230
231        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
232        for select_candidate in list(qualify_filters.find_all(select_candidates)):
233            if isinstance(select_candidate, exp.Window):
234                if expression_by_alias:
235                    for column in select_candidate.find_all(exp.Column):
236                        expr = expression_by_alias.get(column.name)
237                        if expr:
238                            column.replace(expr)
239
240                alias = find_new_name(expression.named_selects, "_w")
241                expression.select(exp.alias_(select_candidate, alias), copy=False)
242                column = exp.column(alias)
243
244                if isinstance(select_candidate.parent, exp.Qualify):
245                    qualify_filters = column
246                else:
247                    select_candidate.replace(column)
248            elif select_candidate.name not in expression.named_selects:
249                expression.select(select_candidate.copy(), copy=False)
250
251        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
252            qualify_filters, copy=False
253        )
254
255    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
258def remove_precision_parameterized_types(expression: exp.Expr) -> exp.Expr:
259    """
260    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
261    other expressions. This transforms removes the precision from parameterized types in expressions.
262    """
263    for node in expression.find_all(exp.DataType):
264        node.set(
265            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
266        )
267
268    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
271def unqualify_unnest(expression: exp.Expr) -> exp.Expr:
272    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
273    from sqlglot.optimizer.scope import find_all_in_scope
274
275    if isinstance(expression, exp.Select):
276        unnest_aliases = {
277            unnest.alias
278            for unnest in find_all_in_scope(expression, exp.Unnest)
279            if isinstance(unnest.parent, (exp.From, exp.Join))
280        }
281        if unnest_aliases:
282            for column in expression.find_all(exp.Column):
283                leftmost_part = column.parts[0]
284                if leftmost_part.arg_key != "this" and leftmost_part.this in unnest_aliases:
285                    leftmost_part.pop()
286
287    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.core.Expr, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.core.Expr:
290def unnest_to_explode(
291    expression: exp.Expr,
292    unnest_using_arrays_zip: bool = True,
293) -> exp.Expr:
294    """Convert cross join unnest into lateral view explode."""
295
296    def _unnest_zip_exprs(
297        u: exp.Unnest, unnest_exprs: list[exp.Expr], has_multi_expr: bool
298    ) -> list[exp.Expr]:
299        if has_multi_expr:
300            if not unnest_using_arrays_zip:
301                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")
302
303            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
304            zip_exprs: list[exp.Expr] = [exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)]
305            u.set("expressions", zip_exprs)
306            return zip_exprs
307        return unnest_exprs
308
309    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> type[exp.Func]:
310        if u.args.get("offset"):
311            return exp.Posexplode
312        return exp.Inline if has_multi_expr else exp.Explode
313
314    if isinstance(expression, exp.Select):
315        from_ = expression.args.get("from_")
316
317        if from_ and isinstance(from_.this, exp.Unnest):
318            unnest = from_.this
319            alias = unnest.args.get("alias")
320            exprs = unnest.expressions
321            has_multi_expr = len(exprs) > 1
322            this, *_ = _unnest_zip_exprs(unnest, exprs, has_multi_expr)
323
324            columns = alias.columns if alias else []
325            offset = unnest.args.get("offset")
326            if offset:
327                columns.insert(
328                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
329                )
330
331            unnest.replace(
332                exp.Table(
333                    this=_udtf_type(unnest, has_multi_expr)(this=this),
334                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
335                )
336            )
337
338        joins = expression.args.get("joins") or []
339        for join in list(joins):
340            join_expr = join.this
341
342            is_lateral = isinstance(join_expr, exp.Lateral)
343
344            unnest = join_expr.this if is_lateral else join_expr
345
346            if isinstance(unnest, exp.Unnest):
347                if is_lateral:
348                    alias = join_expr.args.get("alias")
349                else:
350                    alias = unnest.args.get("alias")
351                exprs = unnest.expressions
352                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
353                has_multi_expr = len(exprs) > 1
354                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)
355
356                joins.remove(join)
357
358                alias_cols = alias.columns if alias else []
359
360                # # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
361                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
362                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html
363
364                if not has_multi_expr and len(alias_cols) not in (1, 2):
365                    raise UnsupportedError(
366                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
367                    )
368
369                offset = unnest.args.get("offset")
370                if offset:
371                    alias_cols.insert(
372                        0,
373                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
374                    )
375
376                for e, column in zip(exprs, alias_cols):
377                    expression.append(
378                        "laterals",
379                        exp.Lateral(
380                            this=_udtf_type(unnest, has_multi_expr)(this=e),
381                            view=True,
382                            alias=exp.TableAlias(
383                                this=alias.this,  # type: ignore
384                                columns=alias_cols,
385                            ),
386                        ),
387                    )
388
389    return expression

Convert cross join unnest into lateral view explode.

def explode_projection_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.core.Expr], sqlglot.expressions.core.Expr]:
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expr], exp.Expr]:
    """
    Convert explode/posexplode projections into unnests.

    Returns a transform that rewrites SELECTs whose projections contain EXPLODE /
    POSEXPLODE calls into an equivalent query that CROSS JOINs a generated position
    series against UNNESTed arrays — for dialects without EXPLODE (e.g. Trino, per
    the join comment below).

    Args:
        index_offset: the target dialect's array index base (0 or 1); it seeds the
            generated series and adjusts the array-size bounds.
    """

    def _explode_projection_to_unnest(expression: exp.Expr) -> exp.Expr:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use by projections / sources, so the generated
            # aliases below don't collide with them
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: set[str], name: str) -> str:
                # Pick a fresh name and record it as taken
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE term per rewritten explode; non-empty also means
            # "at least one projection was rewritten"
            arrays: list[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Shared position series; its "end" bound is filled in at the very
            # end, once all array sizes are known
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias: t.Any = ""
                    explode_alias: t.Any = ""

                    # Normalize the projection into an exp.Alias, reusing any
                    # aliases the user already provided
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias: exp.Expr = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # When the array is empty or NULL, fall back to a
                        # single-element array (arg[0] with safe/offset set) so
                        # that a row is still produced
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Only emit the unnested value on rows where the series
                    # position matches this unnest's offset column
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, inserted right
                        # after the value projection
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First rewritten explode: attach the shared position series
                        if expression.args.get("from_"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions line up, or — for arrays
                    # shorter than the series — pin the unnest to its last
                    # position so the row counts still match
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must run up to the size of the largest array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest

Convert explode/posexplode projections into unnests.

def add_within_group_for_percentiles( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def add_within_group_for_percentiles(expression: exp.Expr) -> exp.Expr:
    """Wraps bare percentile functions in a WITHIN GROUP (ORDER BY ...) clause."""
    is_bare_percentile = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not is_bare_percentile:
        return expression

    # The column being aggregated moves into the ORDER BY, and the percentile
    # value becomes the function's argument.
    column = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=column)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def remove_within_group_for_percentiles(expression: exp.Expr) -> exp.Expr:
    """Rewrites `<percentile> WITHIN GROUP (ORDER BY ...)` as an APPROX_QUANTILE call."""
    if not isinstance(expression, exp.WithinGroup):
        return expression
    if not isinstance(expression.this, exp.PERCENTILES):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = expression.this.this
    # The ordered expression supplies the value being aggregated
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def add_recursive_cte_column_names(expression: exp.Expr) -> exp.Expr:
    """Derives explicit CTE column lists from projection names in recursive WITH clauses."""
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    generate_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            # Only the left-most branch's projections name the columns
            query = query.this

        columns = [
            exp.to_identifier(projection.alias_or_name or generate_name())
            for projection in query.selects
        ]
        cte_alias.set("columns", columns)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def epoch_cast_to_ts(expression: exp.Expr) -> exp.Expr:
    """Replace 'epoch' in casts by the equivalent date literal."""
    is_epoch_cast = (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    )
    if is_epoch_cast:
        # The unix epoch, spelled as a timestamp literal
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def eliminate_semi_and_anti_joins(expression: exp.Expr) -> exp.Expr:
    """Convert SEMI and ANTI joins into equivalent [NOT] EXISTS predicates."""
    if not isinstance(expression, exp.Select):
        return expression

    # Iterate over a copy, since joins are popped from the tree below
    for join in list(expression.args.get("joins") or []):
        on = join.args.get("on")
        if not on or join.kind not in ("SEMI", "ANTI"):
            continue

        subquery = exp.select("1").from_(join.this).where(on)
        predicate: exp.Exists | exp.Not = exp.Exists(this=subquery)
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def eliminate_full_outer_join(expression: exp.Expr) -> exp.Expr:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # Copy before mutating: the copy becomes the RIGHT-join half
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # ON condition if present, otherwise reconstruct equality
            # predicates from the USING column list
            tables = (expression.args["from_"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            # The RIGHT half excludes rows already produced by the LEFT half,
            # via a NOT EXISTS anti-join on the same conditions
            anti_join_clause = (
                exp.select("1").from_(expression.args["from_"]).where(join_conditions)
            )
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.set("with_", None)  # remove CTEs from RIGHT side
            expression.set("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: ~E) -> ~E:
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with_")
    for inner_with in expression.find_all(exp.With):
        # Skip the WITH that is already attached to the top-level expression
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: detach the nested one and promote it as-is
            top_level_with = inner_with.pop()
            expression.set("with_", top_level_with)
        else:
            # RECURSIVE is sticky: any recursive nested WITH makes the merged one recursive
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # The nested WITH lives inside another CTE: splice its CTEs in
                # just before that CTE, so definitions precede their use
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def ensure_bools(expression: exp.Expr) -> exp.Expr:
    """Converts numeric values used in conditions into explicit boolean expressions.

    Walks the tree and, via the canonicalizer's `ensure_bools` hook, replaces every
    condition operand that looks numeric with `operand <> 0`.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expr) -> None:
        # Rewrite the node when it is numeric-like: a number literal, a
        # non-subquery expression typed numeric/unknown, or an untyped column.
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                # Fixed: was `exp.DType.UNKNOWN`, inconsistent with the
                # `exp.DataType` accesses in the same call and elsewhere in
                # this module.
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def unqualify_columns(expression: exp.Expr) -> exp.Expr:
    """Strips qualifiers (table, db, catalog) from every column, keeping only its name."""
    for column in expression.find_all(exp.Column):
        # The last part is the column name itself; everything before it is a qualifier
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def remove_unique_constraints(expression: exp.Expr) -> exp.Expr:
    """Drops UNIQUE column constraints from a CREATE statement."""
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        # Pop the enclosing constraint node, not just the UNIQUE marker
        enclosing = unique.parent
        if enclosing:
            enclosing.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.core.Expr, tmp_storage_provider: Callable[[sqlglot.expressions.core.Expr], sqlglot.expressions.core.Expr] = <function <lambda>>) -> sqlglot.expressions.core.Expr:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expr,
    tmp_storage_provider: t.Callable[[exp.Expr], exp.Expr] = lambda e: e,
) -> exp.Expr:
    """Maps `CREATE TEMPORARY TABLE ... AS <query>` onto `CREATE TEMPORARY VIEW`."""
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    prop_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in prop_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    # No SELECT body: delegate to the caller-supplied storage handler
    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this) or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    partition_names = {v.name.upper() for v in prop.this.expressions}
    partitions = [col for col in schema.expressions if col.name.upper() in partition_names]

    # Remove the partition columns from the table schema...
    schema.set("expressions", [e for e in schema.expressions if e not in partitions])
    # ...and re-home them under PARTITIONED BY as a schema of column definitions
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def move_partitioned_by_to_schema_columns(expression: exp.Expr) -> exp.Expr:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    has_typed_partition_schema = (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    )
    if not has_typed_partition_schema:
        return expression

    partition_defs = prop.this.expressions
    # PARTITIONED BY keeps only the bare column names...
    bare_names = exp.Tuple(
        expressions=[exp.to_identifier(e.this) for e in partition_defs]
    )

    # ...while the typed column definitions move into the table schema
    schema = expression.this
    for column_def in partition_defs:
        schema.append("expressions", column_def)

    prop.set("this", bare_names)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def struct_kv_to_alias(expression: exp.Expr) -> exp.Expr:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if not isinstance(expression, exp.Struct):
        return expression

    converted = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            # y := 1  ->  1 AS y
            arg = exp.alias_(arg.expression, arg.this)
        converted.append(arg)

    expression.set("expressions", converted)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def eliminate_join_marks(expression: exp.Expr) -> exp.Expr:
    """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

    -- example with WHERE
    SELECT d.department_name, sum(e.salary) as total_salary
    FROM departments d, employees e
    WHERE e.department_id(+) = d.department_id
    group by department_name

    -- example of left correlation in select
    SELECT d.department_name, (
        SELECT SUM(e.salary)
            FROM employees e
            WHERE e.department_id(+) = d.department_id) AS total_salary
    FROM departments d;

    -- example of left correlation in from
    SELECT d.department_name, t.total_salary
    FROM departments d, (
            SELECT SUM(e.salary) AS total_salary
            FROM employees e
            WHERE e.department_id(+) = d.department_id
        ) t
    """

    from sqlglot.optimizer.scope import traverse_scope
    from sqlglot.optimizer.normalize import normalize, normalized
    from collections import defaultdict

    # we go in reverse to check the main query for left correlation
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins", [])

        # Nothing to do unless the WHERE clause contains (+)-marked columns
        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
            continue

        # knockout: we do not support left correlation (see point 2)
        assert not scope.is_correlated_subquery, "Correlated queries are not supported"

        # make sure we have AND of ORs to have clear join terms
        where = normalize(where.this)
        assert normalized(where), "Cannot normalize JOIN predicates"

        joins_ons = defaultdict(list)  # dict of {name: list of join AND conditions}
        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
            # (+)-marked columns identify the left-joined table for this condition
            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]

            left_join_table = set(col.table for col in join_cols)
            if not left_join_table:
                continue

            assert not (len(left_join_table) > 1), (
                "Cannot combine JOIN predicates from different tables"
            )

            # Drop the marks; the predicate becomes a LEFT JOIN ... ON term below
            for col in join_cols:
                col.set("join_mark", False)

            joins_ons[left_join_table.pop()].append(cond)

        old_joins = {join.alias_or_name: join for join in joins}
        new_joins = {}
        query_from = query.args["from_"]

        for table, predicates in joins_ons.items():
            # The marked table may come from an existing join or from the FROM clause
            join_what = old_joins.get(table, query_from).this.copy()
            new_joins[join_what.alias_or_name] = exp.Join(
                this=join_what, on=exp.and_(*predicates), kind="LEFT"
            )

            # Remove each predicate from WHERE, repairing the surrounding
            # boolean expression as nodes are popped
            for p in predicates:
                while isinstance(p.parent, exp.Paren):
                    p.parent.replace(p)

                parent = p.parent
                p.pop()
                if isinstance(parent, exp.Binary):
                    left = parent.args.get("this")
                    parent.replace(parent.right if left is None else left)
                elif isinstance(parent, exp.Where):
                    parent.pop()

        if query_from.alias_or_name in new_joins:
            # The FROM table itself became LEFT JOINed: promote one of the
            # remaining (unmarked) join tables into the FROM clause
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert len(only_old_joins) >= 1, (
                "Cannot determine which table to use in the new FROM clause"
            )

            new_from_name = list(only_old_joins)[0]
            query.set("from_", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            for n, j in old_joins.items():  # preserve any other joins
                if n not in new_joins and n != query.args["from_"].name:
                    if not j.kind:
                        j.set("kind", "CROSS")
                    new_joins[n] = j
            query.set("joins", list(new_joins.values()))

    return expression

https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

  1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

  2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

-- example with WHERE SELECT d.department_name, sum(e.salary) as total_salary FROM departments d, employees e WHERE e.department_id(+) = d.department_id group by department_name

-- example of left correlation in select SELECT d.department_name, ( SELECT SUM(e.salary) FROM employees e WHERE e.department_id(+) = d.department_id) AS total_salary FROM departments d;

-- example of left correlation in from SELECT d.department_name, t.total_salary FROM departments d, ( SELECT SUM(e.salary) AS total_salary FROM employees e WHERE e.department_id(+) = d.department_id ) t

def any_to_exists( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def any_to_exists(expression: exp.Expr) -> exp.Expr:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if not isinstance(expression, exp.Select):
        return expression

    for any_expr in expression.find_all(exp.Any):
        array_arg = any_expr.this
        # Subqueries and [I]LIKE ANY are left untouched
        if isinstance(array_arg, exp.Query) or isinstance(any_expr.parent, (exp.Like, exp.ILike)):
            continue

        comparison = any_expr.parent
        if not isinstance(comparison, exp.Binary):
            continue

        # Substitute ANY(...) with the lambda variable, then lift the whole
        # comparison into the EXISTS lambda body
        lambda_arg = exp.to_identifier("x")
        any_expr.replace(lambda_arg)
        lambda_expr = exp.Lambda(this=comparison.copy(), expressions=[lambda_arg])
        comparison.replace(exp.Exists(this=array_arg.unnest(), expression=lambda_expr))

    return expression

Transform ANY operator to Spark's EXISTS

For example, - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation

def eliminate_window_clause( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def eliminate_window_clause(expression: exp.Expr) -> exp.Expr:
    """Eliminates the `WINDOW` query clause by inlining each named window."""
    if isinstance(expression, exp.Select) and expression.args.get("windows"):
        from sqlglot.optimizer.scope import find_all_in_scope

        # Detach the WINDOW clause; its definitions get inlined below
        windows = expression.args["windows"]
        expression.set("windows", None)

        # Named window definitions seen so far, keyed by lowercased name
        window_expression: dict[str, exp.Expr] = {}

        def _inline_inherited_window(window: exp.Expr) -> None:
            # Resolve the named window this one references, if any
            inherited_window = window_expression.get(window.alias.lower())
            if not inherited_window:
                return

            # Replace the reference with copies of the inherited window's parts
            window.set("alias", None)
            for key in ("partition_by", "order", "spec"):
                arg = inherited_window.args.get(key)
                if arg:
                    window.set(key, arg.copy())

        # First resolve chained references between the named windows themselves
        for window in windows:
            _inline_inherited_window(window)
            window_expression[window.name.lower()] = window

        # Then inline into every OVER (...) usage within the query's scope
        for window in find_all_in_scope(expression, exp.Window):
            _inline_inherited_window(window)

    return expression

Eliminates the WINDOW query clause by inlining each named window.

def inherit_struct_field_names( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def inherit_struct_field_names(expression: exp.Expr) -> exp.Expr:
    """
    Inherit field names from the first struct in an array.

    BigQuery supports implicitly inheriting names from the first STRUCT in an array:

    Example:
        ARRAY[
          STRUCT('Alice' AS name, 85 AS score),  -- defines names
          STRUCT('Bob', 92),                     -- inherits names
          STRUCT('Diana', 95)                    -- inherits names
        ]

    This transformation makes the field names explicit on all structs by adding
    PropertyEQ nodes, in order to facilitate transpilation to other dialects.

    Args:
        expression: The expression tree to transform

    Returns:
        The modified expression with field names inherited in all structs
    """
    if not isinstance(expression, exp.Array) or not expression.args.get(
        "struct_name_inheritance"
    ):
        return expression

    first_item = seq_get(expression.expressions, 0)

    # Names can only be inherited when the leading struct names every field.
    if not isinstance(first_item, exp.Struct) or not all(
        isinstance(fld, exp.PropertyEQ) for fld in first_item.expressions
    ):
        return expression

    field_names = [fld.this for fld in first_item.expressions]

    def _named(name: exp.Expr, value: exp.Expr) -> exp.Expr:
        # Already named -> keep as-is; otherwise wrap in a PropertyEQ
        # (field_name := value), preserving the inner expression's type.
        if isinstance(value, exp.PropertyEQ):
            return value

        named = exp.PropertyEQ(this=name.copy(), expression=value)
        named.type = value.type
        return named

    # Apply field names to subsequent structs whose arity matches the first.
    for struct in expression.expressions[1:]:
        if isinstance(struct, exp.Struct) and len(struct.expressions) == len(field_names):
            struct.set(
                "expressions",
                [_named(name, value) for name, value in zip(field_names, struct.expressions)],
            )

    return expression

Inherit field names from the first struct in an array.

BigQuery supports implicitly inheriting names from the first STRUCT in an array:

Example:

ARRAY[ STRUCT('Alice' AS name, 85 AS score), -- defines names STRUCT('Bob', 92), -- inherits names STRUCT('Diana', 95) -- inherits names ]

This transformation makes the field names explicit on all structs by adding PropertyEQ nodes, in order to facilitate transpilation to other dialects.

Arguments:
  • expression: The expression tree to transform
Returns:

The modified expression with field names inherited in all structs