Edit on GitHub

sqlglot.transforms

   1from __future__ import annotations
   2
   3import typing as t
   4
   5from sqlglot import expressions as exp
   6from sqlglot.errors import UnsupportedError
   7from sqlglot.helper import find_new_name, name_sequence, seq_get
   8
   9
  10if t.TYPE_CHECKING:
  11    from sqlglot._typing import E
  12    from sqlglot.generator import Generator
  13
  14
  15def preprocess(
  16    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
  17    generator: t.Optional[t.Callable[[Generator, exp.Expression], str]] = None,
  18) -> t.Callable[[Generator, exp.Expression], str]:
  19    """
  20    Creates a new transform by chaining a sequence of transformations and converts the resulting
  21    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
  22    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
  23
  24    Args:
  25        transforms: sequence of transform functions. These will be called in order.
  26
  27    Returns:
  28        Function that can be used as a generator transform.
  29    """
  30
  31    def _to_sql(self, expression: exp.Expression) -> str:
  32        expression_type = type(expression)
  33
  34        try:
  35            expression = transforms[0](expression)
  36            for transform in transforms[1:]:
  37                expression = transform(expression)
  38        except UnsupportedError as unsupported_error:
  39            self.unsupported(str(unsupported_error))
  40
  41        if generator:
  42            return generator(self, expression)
  43
  44        _sql_handler = getattr(self, expression.key + "_sql", None)
  45        if _sql_handler:
  46            return _sql_handler(expression)
  47
  48        transforms_handler = self.TRANSFORMS.get(type(expression))
  49        if transforms_handler:
  50            if expression_type is type(expression):
  51                if isinstance(expression, exp.Func):
  52                    return self.function_fallback_sql(expression)
  53
  54                # Ensures we don't enter an infinite loop. This can happen when the original expression
  55                # has the same type as the final expression and there's no _sql method available for it,
  56                # because then it'd re-enter _to_sql.
  57                raise ValueError(
  58                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
  59                )
  60
  61            return transforms_handler(self, expression)
  62
  63        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
  64
  65    return _to_sql
  66
  67
  68def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
  69    if isinstance(expression, exp.Select):
  70        count = 0
  71        recursive_ctes = []
  72
  73        for unnest in expression.find_all(exp.Unnest):
  74            if (
  75                not isinstance(unnest.parent, (exp.From, exp.Join))
  76                or len(unnest.expressions) != 1
  77                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
  78            ):
  79                continue
  80
  81            generate_date_array = unnest.expressions[0]
  82            start = generate_date_array.args.get("start")
  83            end = generate_date_array.args.get("end")
  84            step = generate_date_array.args.get("step")
  85
  86            if not start or not end or not isinstance(step, exp.Interval):
  87                continue
  88
  89            alias = unnest.args.get("alias")
  90            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
  91
  92            start = exp.cast(start, "date")
  93            date_add = exp.func(
  94                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
  95            )
  96            cast_date_add = exp.cast(date_add, "date")
  97
  98            cte_name = "_generated_dates" + (f"_{count}" if count else "")
  99
 100            base_query = exp.select(start.as_(column_name))
 101            recursive_query = (
 102                exp.select(cast_date_add)
 103                .from_(cte_name)
 104                .where(cast_date_add <= exp.cast(end, "date"))
 105            )
 106            cte_query = base_query.union(recursive_query, distinct=False)
 107
 108            generate_dates_query = exp.select(column_name).from_(cte_name)
 109            unnest.replace(generate_dates_query.subquery(cte_name))
 110
 111            recursive_ctes.append(
 112                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
 113            )
 114            count += 1
 115
 116        if recursive_ctes:
 117            with_expression = expression.args.get("with_") or exp.With()
 118            with_expression.set("recursive", True)
 119            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
 120            expression.set("with_", with_expression)
 121
 122    return expression
 123
 124
 125def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
 126    """Unnests GENERATE_SERIES or SEQUENCE table references."""
 127    this = expression.this
 128    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
 129        unnest = exp.Unnest(expressions=[this])
 130        if expression.alias:
 131            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
 132
 133        return unnest
 134
 135    return expression
 136
 137
 138def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 139    """
 140    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 141
 142    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 143
 144    Args:
 145        expression: the expression that will be transformed.
 146
 147    Returns:
 148        The transformed expression.
 149    """
 150    if (
 151        isinstance(expression, exp.Select)
 152        and expression.args.get("distinct")
 153        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
 154    ):
 155        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")
 156
 157        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 158        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 159
 160        order = expression.args.get("order")
 161        if order:
 162            window.set("order", order.pop())
 163        else:
 164            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 165
 166        window = exp.alias_(window, row_number_window_alias)
 167        expression.select(window, copy=False)
 168
 169        # We add aliases to the projections so that we can safely reference them in the outer query
 170        new_selects = []
 171        taken_names = {row_number_window_alias}
 172        for select in expression.selects[:-1]:
 173            if select.is_star:
 174                new_selects = [exp.Star()]
 175                break
 176
 177            if not isinstance(select, exp.Alias):
 178                alias = find_new_name(taken_names, select.output_name or "_col")
 179                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
 180                select = select.replace(exp.alias_(select, alias, quoted=quoted))
 181
 182            taken_names.add(select.output_name)
 183            new_selects.append(select.args["alias"])
 184
 185        return (
 186            exp.select(*new_selects, copy=False)
 187            .from_(expression.subquery("_t", copy=False), copy=False)
 188            .where(exp.column(row_number_window_alias).eq(1), copy=False)
 189        )
 190
 191    return expression
 192
 193
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        `SELECT <original projections> FROM (<expression>) AS _t WHERE <former QUALIFY condition>`
        when `expression` is a SELECT with a QUALIFY clause; otherwise `expression` unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh alias so the outer query can reference it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Reference the projection by name, keeping its identifier's quoting when available
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Computed before any helper projections are added below, so the outer query only
        # returns the original projections
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach the QUALIFY clause; its condition becomes the outer WHERE
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With SELECT *, every column is already exposed by the subquery, so only window
        # functions need to be projected
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                # Replace references to projection aliases inside the window with the aliased
                # expressions, to avoid creating invalid column references.
                # NOTE(review): the aliased expression is inserted without a copy, so the same
                # node object is referenced from both the projection and this window — confirm
                # this sharing is intended
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh alias and filter on that alias instead
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # If the window is the entire QUALIFY condition, its parent is the detached
                # Qualify node, so replacing it in place wouldn't update `qualify_filters`;
                # use the alias column as the filter directly
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Project columns referenced by QUALIFY that the subquery doesn't already expose
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
 256
 257
 258def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
 259    """
 260    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
 261    other expressions. This transforms removes the precision from parameterized types in expressions.
 262    """
 263    for node in expression.find_all(exp.DataType):
 264        node.set(
 265            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
 266        )
 267
 268    return expression
 269
 270
 271def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
 272    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
 273    from sqlglot.optimizer.scope import find_all_in_scope
 274
 275    if isinstance(expression, exp.Select):
 276        unnest_aliases = {
 277            unnest.alias
 278            for unnest in find_all_in_scope(expression, exp.Unnest)
 279            if isinstance(unnest.parent, (exp.From, exp.Join))
 280        }
 281        if unnest_aliases:
 282            for column in expression.find_all(exp.Column):
 283                leftmost_part = column.parts[0]
 284                if leftmost_part.arg_key != "this" and leftmost_part.this in unnest_aliases:
 285                    leftmost_part.pop()
 286
 287    return expression
 288
 289
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause is replaced by a table built from an EXPLODE / POSEXPLODE /
    INLINE call. Each UNNEST join (optionally wrapped in LATERAL) is removed from the joins and
    re-emitted as a LATERAL VIEW.

    Args:
        expression: the expression that will be transformed.
        unnest_using_arrays_zip: whether an UNNEST over multiple arrays may be rewritten as
            INLINE(ARRAYS_ZIP(...)); when False, such UNNESTs raise UnsupportedError.

    Returns:
        The transformed expression (modified in place when it is a SELECT).

    Raises:
        UnsupportedError: for multi-array UNNESTs when `unnest_using_arrays_zip` is False, or for
            a single-array UNNEST join whose alias doesn't define exactly 1 or 2 columns.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapse multiple UNNEST inputs into a single ARRAYS_ZIP(...) call, also updating the
        # UNNEST node itself; a single input is returned unchanged
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # An offset requires POSEXPLODE; zipped arrays use INLINE; otherwise plain EXPLODE
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from_")

        # FROM UNNEST(...) -> FROM <UDTF>(...) AS alias(columns)
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *_ = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            columns = alias.columns if alias else []
            offset = unnest.args.get("offset")
            # Prepend the position column: the offset's own identifier, or `pos` by default
            if offset:
                columns.insert(
                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
                )

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(this=this),
                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        # Iterate over a copy because matching joins are removed from `joins` below
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For LATERAL UNNEST the alias lives on the Lateral node, not on the Unnest
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                # The UNNEST join is dropped; it's re-emitted below as a LATERAL VIEW
                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: an exception is raised when a single-array
                # UNNEST has 0 or more than 2 column aliases.
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                offset = unnest.args.get("offset")
                # Prepend the position column: the offset's own identifier, or `pos` by default
                if offset:
                    alias_cols.insert(
                        0,
                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
                    )

                # `exprs` holds at most one expression here (multiple arrays were zipped into one
                # above), so at most one LATERAL VIEW is appended.
                # NOTE(review): for a multi-array UNNEST without an alias, `alias_cols` is empty
                # unless an offset was added. With no offset, the loop body never runs and the
                # join is silently dropped. With an offset, `alias.this` is evaluated on None.
                # Confirm callers never hit these cases.
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
 392
 393
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode projections into unnests.

    The returned transform rewrites each SELECT projection containing an EXPLODE node
    (POSEXPLODE and EXPLODE_OUTER get extra handling). Each such array is CROSS JOINed as an
    UNNEST with an offset column. The rows are aligned with a shared position series
    UNNEST(GENERATE_SERIES(index_offset, <largest array's last position>)) via WHERE filters,
    and the projection is replaced with IF(series position = array position, element).

    Args:
        index_offset: start value of the generated position series; presumably the target
            dialect's array index base (0 or 1) — TODO confirm against callers.

    Returns:
        A transform that takes an expression and returns the transformed expression.
    """

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            # Pick a name not in `names` and reserve it
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded array, used to compute the series' end
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Shared position series; its end is filled in after the loop, once all arrays are known
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an alias node (`alias`) whose name can be set
                    # below; an Aliases node provides both the position and the element names
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # EXPLODE_OUTER keeps a row for empty/NULL arrays: substitute a one-element
                    # array holding a safely indexed first element (presumably NULL in that case)
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Emit the element only on the row where the series position matches the
                    # array position
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also exposes the position, as a projection right after the element
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # Attach the position series once, on the first exploded projection
                    if not arrays:
                        if expression.args.get("from_"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` now denotes the array's last position (size - 1 unless offset is 1)
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, or where the series has run past this
                    # array's last position and the array is pinned to its last element (the IF
                    # projections above then yield NULL for those extra rows)
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            # The series spans up to the largest array: GREATEST(sizes) - (1 - index_offset)
            # unless index_offset is 1
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest
 545
 546
 547def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
 548    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
 549    if (
 550        isinstance(expression, exp.PERCENTILES)
 551        and not isinstance(expression.parent, exp.WithinGroup)
 552        and expression.expression
 553    ):
 554        column = expression.this.pop()
 555        expression.set("this", expression.expression.pop())
 556        order = exp.Order(expressions=[exp.Ordered(this=column)])
 557        expression = exp.WithinGroup(this=expression, expression=order)
 558
 559    return expression
 560
 561
 562def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
 563    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
 564    if (
 565        isinstance(expression, exp.WithinGroup)
 566        and isinstance(expression.this, exp.PERCENTILES)
 567        and isinstance(expression.expression, exp.Order)
 568    ):
 569        quantile = expression.this.this
 570        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
 571        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
 572
 573    return expression
 574
 575
 576def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
 577    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
 578    if isinstance(expression, exp.With) and expression.recursive:
 579        next_name = name_sequence("_c_")
 580
 581        for cte in expression.expressions:
 582            if not cte.args["alias"].columns:
 583                query = cte.this
 584                if isinstance(query, exp.SetOperation):
 585                    query = query.this
 586
 587                cte.args["alias"].set(
 588                    "columns",
 589                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
 590                )
 591
 592    return expression
 593
 594
 595def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
 596    """Replace 'epoch' in casts by the equivalent date literal."""
 597    if (
 598        isinstance(expression, (exp.Cast, exp.TryCast))
 599        and expression.name.lower() == "epoch"
 600        and expression.to.this in exp.DataType.TEMPORAL_TYPES
 601    ):
 602        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
 603
 604    return expression
 605
 606
 607def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
 608    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
 609    if isinstance(expression, exp.Select):
 610        for join in expression.args.get("joins") or []:
 611            on = join.args.get("on")
 612            if on and join.kind in ("SEMI", "ANTI"):
 613                subquery = exp.select("1").from_(join.this).where(on)
 614                exists = exp.Exists(this=subquery)
 615                if join.kind == "ANTI":
 616                    exists = exists.not_(copy=False)
 617
 618                join.pop()
 619                expression.where(exists, copy=False)
 620
 621    return expression
 622
 623
 624def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
 625    """
 626    Converts a query with a FULL OUTER join to a union of identical queries that
 627    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
 628    for queries that have a single FULL OUTER join.
 629    """
 630    if isinstance(expression, exp.Select):
 631        full_outer_joins = [
 632            (index, join)
 633            for index, join in enumerate(expression.args.get("joins") or [])
 634            if join.side == "FULL"
 635        ]
 636
 637        if len(full_outer_joins) == 1:
 638            expression_copy = expression.copy()
 639            expression.set("limit", None)
 640            index, full_outer_join = full_outer_joins[0]
 641
 642            tables = (expression.args["from_"].alias_or_name, full_outer_join.alias_or_name)
 643            join_conditions = full_outer_join.args.get("on") or exp.and_(
 644                *[
 645                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
 646                    for col in full_outer_join.args.get("using")
 647                ]
 648            )
 649
 650            full_outer_join.set("side", "left")
 651            anti_join_clause = (
 652                exp.select("1").from_(expression.args["from_"]).where(join_conditions)
 653            )
 654            expression_copy.args["joins"][index].set("side", "right")
 655            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
 656            expression_copy.set("with_", None)  # remove CTEs from RIGHT side
 657            expression.set("order", None)  # remove order by from LEFT side
 658
 659            return exp.union(expression, expression_copy, copy=False, distinct=False)
 660
 661    return expression
 662
 663
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression whose nested WITH clauses are hoisted; modified in place.

    Returns:
        The same expression, with nested CTEs merged into its top-level WITH clause.
    """
    top_level_with = expression.args.get("with_")
    for inner_with in expression.find_all(exp.With):
        # Skip the WITH clause that is already attached directly to `expression`
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote this nested one wholesale
            top_level_with = inner_with.pop()
            expression.set("with_", top_level_with)
        else:
            # Any recursive nested WITH makes the merged top-level WITH recursive
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the nested CTEs right before the CTE that contained them, so they are
                # defined before it.
                # NOTE(review): assumes parent_cte is already in top_level_with.expressions;
                # list.index raises ValueError otherwise
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                # Re-set the list, presumably so the inserted CTEs get re-parented onto
                # top_level_with
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                # Not nested inside a CTE: append the nested CTEs at the end
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
 701
 702
 703def ensure_bools(expression: exp.Expression) -> exp.Expression:
 704    """Converts numeric values used in conditions into explicit boolean expressions."""
 705    from sqlglot.optimizer.canonicalize import ensure_bools
 706
 707    def _ensure_bool(node: exp.Expression) -> None:
 708        if (
 709            node.is_number
 710            or (
 711                not isinstance(node, exp.SubqueryPredicate)
 712                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
 713            )
 714            or (isinstance(node, exp.Column) and not node.type)
 715        ):
 716            node.replace(node.neq(0))
 717
 718    for node in expression.walk():
 719        ensure_bools(node, _ensure_bool)
 720
 721    return expression
 722
 723
 724def unqualify_columns(expression: exp.Expression) -> exp.Expression:
 725    for column in expression.find_all(exp.Column):
 726        # We only wanna pop off the table, db, catalog args
 727        for part in column.parts[:-1]:
 728            part.pop()
 729
 730    return expression
 731
 732
 733def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
 734    assert isinstance(expression, exp.Create)
 735    for constraint in expression.find_all(exp.UniqueColumnConstraint):
 736        if constraint.parent:
 737            constraint.parent.pop()
 738
 739    return expression
 740
 741
 742def ctas_with_tmp_tables_to_create_tmp_view(
 743    expression: exp.Expression,
 744    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
 745) -> exp.Expression:
 746    assert isinstance(expression, exp.Create)
 747    properties = expression.args.get("properties")
 748    temporary = any(
 749        isinstance(prop, exp.TemporaryProperty)
 750        for prop in (properties.expressions if properties else [])
 751    )
 752
 753    # CTAS with temp tables map to CREATE TEMPORARY VIEW
 754    if expression.kind == "TABLE" and temporary:
 755        if expression.expression:
 756            return exp.Create(
 757                kind="TEMPORARY VIEW",
 758                this=expression.this,
 759                expression=expression.expression,
 760            )
 761        return tmp_storage_provider(expression)
 762
 763    return expression
 764
 765
 766def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
 767    """
 768    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
 769    PARTITIONED BY value is an array of column names, they are transformed into a schema.
 770    The corresponding columns are removed from the create statement.
 771    """
 772    assert isinstance(expression, exp.Create)
 773    has_schema = isinstance(expression.this, exp.Schema)
 774    is_partitionable = expression.kind in {"TABLE", "VIEW"}
 775
 776    if has_schema and is_partitionable:
 777        prop = expression.find(exp.PartitionedByProperty)
 778        if prop and prop.this and not isinstance(prop.this, exp.Schema):
 779            schema = expression.this
 780            columns = {v.name.upper() for v in prop.this.expressions}
 781            partitions = [col for col in schema.expressions if col.name.upper() in columns]
 782            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
 783            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
 784            expression.set("this", schema)
 785
 786    return expression
 787
 788
 789def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
 790    """
 791    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
 792
 793    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
 794    """
 795    assert isinstance(expression, exp.Create)
 796    prop = expression.find(exp.PartitionedByProperty)
 797    if (
 798        prop
 799        and prop.this
 800        and isinstance(prop.this, exp.Schema)
 801        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
 802    ):
 803        prop_this = exp.Tuple(
 804            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
 805        )
 806        schema = expression.this
 807        for e in prop.this.expressions:
 808            schema.append("expressions", e)
 809        prop.set("this", prop_this)
 810
 811    return expression
 812
 813
 814def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
 815    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
 816    if isinstance(expression, exp.Struct):
 817        expression.set(
 818            "expressions",
 819            [
 820                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
 821                for e in expression.expressions
 822            ],
 823        )
 824
 825    return expression
 826
 827
 828def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
 829    """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178
 830
 831    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.
 832
 833    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.
 834
 835    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.
 836
 837    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.
 838
 839    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.
 840
 841    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.
 842
 843    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.
 844
 845    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.
 846
 847    -- example with WHERE
 848    SELECT d.department_name, sum(e.salary) as total_salary
 849    FROM departments d, employees e
 850    WHERE e.department_id(+) = d.department_id
 851    group by department_name
 852
 853    -- example of left correlation in select
 854    SELECT d.department_name, (
 855        SELECT SUM(e.salary)
 856            FROM employees e
 857            WHERE e.department_id(+) = d.department_id) AS total_salary
 858    FROM departments d;
 859
 860    -- example of left correlation in from
 861    SELECT d.department_name, t.total_salary
 862    FROM departments d, (
 863            SELECT SUM(e.salary) AS total_salary
 864            FROM employees e
 865            WHERE e.department_id(+) = d.department_id
 866        ) t
 867    """
 868
 869    from sqlglot.optimizer.scope import traverse_scope
 870    from sqlglot.optimizer.normalize import normalize, normalized
 871    from collections import defaultdict
 872
 873    # we go in reverse to check the main query for left correlation
 874    for scope in reversed(traverse_scope(expression)):
 875        query = scope.expression
 876
 877        where = query.args.get("where")
 878        joins = query.args.get("joins", [])
 879
 880        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
 881            continue
 882
 883        # knockout: we do not support left correlation (see point 2)
 884        assert not scope.is_correlated_subquery, "Correlated queries are not supported"
 885
 886        # make sure we have AND of ORs to have clear join terms
 887        where = normalize(where.this)
 888        assert normalized(where), "Cannot normalize JOIN predicates"
 889
 890        joins_ons = defaultdict(list)  # dict of {name: list of join AND conditions}
 891        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
 892            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]
 893
 894            left_join_table = set(col.table for col in join_cols)
 895            if not left_join_table:
 896                continue
 897
 898            assert not (
 899                len(left_join_table) > 1
 900            ), "Cannot combine JOIN predicates from different tables"
 901
 902            for col in join_cols:
 903                col.set("join_mark", False)
 904
 905            joins_ons[left_join_table.pop()].append(cond)
 906
 907        old_joins = {join.alias_or_name: join for join in joins}
 908        new_joins = {}
 909        query_from = query.args["from_"]
 910
 911        for table, predicates in joins_ons.items():
 912            join_what = old_joins.get(table, query_from).this.copy()
 913            new_joins[join_what.alias_or_name] = exp.Join(
 914                this=join_what, on=exp.and_(*predicates), kind="LEFT"
 915            )
 916
 917            for p in predicates:
 918                while isinstance(p.parent, exp.Paren):
 919                    p.parent.replace(p)
 920
 921                parent = p.parent
 922                p.pop()
 923                if isinstance(parent, exp.Binary):
 924                    parent.replace(parent.right if parent.left is None else parent.left)
 925                elif isinstance(parent, exp.Where):
 926                    parent.pop()
 927
 928        if query_from.alias_or_name in new_joins:
 929            only_old_joins = old_joins.keys() - new_joins.keys()
 930            assert (
 931                len(only_old_joins) >= 1
 932            ), "Cannot determine which table to use in the new FROM clause"
 933
 934            new_from_name = list(only_old_joins)[0]
 935            query.set("from_", exp.From(this=old_joins[new_from_name].this))
 936
 937        if new_joins:
 938            for n, j in old_joins.items():  # preserve any other joins
 939                if n not in new_joins and n != query.args["from_"].name:
 940                    if not j.kind:
 941                        j.set("kind", "CROSS")
 942                    new_joins[n] = j
 943            query.set("joins", list(new_joins.values()))
 944
 945    return expression
 946
 947
 948def any_to_exists(expression: exp.Expression) -> exp.Expression:
 949    """
 950    Transform ANY operator to Spark's EXISTS
 951
 952    For example,
 953        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
 954        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
 955
 956    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
 957    transformation
 958    """
 959    if isinstance(expression, exp.Select):
 960        for any_expr in expression.find_all(exp.Any):
 961            this = any_expr.this
 962            if isinstance(this, exp.Query) or isinstance(any_expr.parent, (exp.Like, exp.ILike)):
 963                continue
 964
 965            binop = any_expr.parent
 966            if isinstance(binop, exp.Binary):
 967                lambda_arg = exp.to_identifier("x")
 968                any_expr.replace(lambda_arg)
 969                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
 970                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))
 971
 972    return expression
 973
 974
 975def eliminate_window_clause(expression: exp.Expression) -> exp.Expression:
 976    """Eliminates the `WINDOW` query clause by inling each named window."""
 977    if isinstance(expression, exp.Select) and expression.args.get("windows"):
 978        from sqlglot.optimizer.scope import find_all_in_scope
 979
 980        windows = expression.args["windows"]
 981        expression.set("windows", None)
 982
 983        window_expression: t.Dict[str, exp.Expression] = {}
 984
 985        def _inline_inherited_window(window: exp.Expression) -> None:
 986            inherited_window = window_expression.get(window.alias.lower())
 987            if not inherited_window:
 988                return
 989
 990            window.set("alias", None)
 991            for key in ("partition_by", "order", "spec"):
 992                arg = inherited_window.args.get(key)
 993                if arg:
 994                    window.set(key, arg.copy())
 995
 996        for window in windows:
 997            _inline_inherited_window(window)
 998            window_expression[window.name.lower()] = window
 999
1000        for window in find_all_in_scope(expression, exp.Window):
1001            _inline_inherited_window(window)
1002
1003    return expression
1004
1005
def inherit_struct_field_names(expression: exp.Expression) -> exp.Expression:
    """
    Inherit field names from the first struct in an array.

    BigQuery supports implicitly inheriting names from the first STRUCT in an array:

    Example:
        ARRAY[
          STRUCT('Alice' AS name, 85 AS score),  -- defines names
          STRUCT('Bob', 92),                     -- inherits names
          STRUCT('Diana', 95)                    -- inherits names
        ]

    This transformation makes the field names explicit on all structs by adding
    PropertyEQ nodes, in order to facilitate transpilation to other dialects.

    It only applies to arrays flagged with `struct_name_inheritance` whose first element is a
    struct made entirely of named (PropertyEQ) fields. Later structs with a different number of
    fields, and fields that are already named, are left untouched. Inherited names are copies of
    the first struct's identifiers, so their quoting is preserved.

    Args:
        expression: The expression tree to transform

    Returns:
        The modified expression with field names inherited in all structs
    """

    def _field_identifier(name: exp.Expression) -> exp.Expression:
        # Reuse a copy of the first struct's identifier so its `quoted` flag survives. Wrapping
        # it in a new Identifier would nest identifiers and drop the quoting.
        if isinstance(name, exp.Identifier):
            return name.copy()
        return exp.Identifier(this=name.copy())

    if (
        isinstance(expression, exp.Array)
        and expression.args.get("struct_name_inheritance")
        and isinstance(first_item := seq_get(expression.expressions, 0), exp.Struct)
        and all(isinstance(fld, exp.PropertyEQ) for fld in first_item.expressions)
    ):
        field_names = [fld.this for fld in first_item.expressions]

        # Apply field names to subsequent structs that don't have them
        for struct in expression.expressions[1:]:
            if not isinstance(struct, exp.Struct) or len(struct.expressions) != len(field_names):
                continue

            # Convert unnamed expressions to PropertyEQ (field_name := value) with inherited names
            struct.set(
                "expressions",
                [
                    expr
                    if isinstance(expr, exp.PropertyEQ)
                    else exp.PropertyEQ(this=_field_identifier(name), expression=expr)
                    for name, expr in zip(field_names, struct.expressions)
                ],
            )

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]], generator: Optional[Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
    generator: t.Optional[t.Callable[[Generator, exp.Expression], str]] = None,
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    SQL for the transformed expression is produced by the first option that applies:

    1. `generator(self, expression)`, if a `generator` is provided.
    2. The generator's `<expression.key>_sql` method, if it exists.
    3. `Generator.TRANSFORMS[type(expression)]`, but only if the transforms changed the type of
       the expression. Otherwise that entry would re-enter this very function, so `Func` nodes
       fall back to `function_fallback_sql` and anything else raises a ValueError.

    If a transform raises UnsupportedError, it is reported through `Generator.unsupported`.
    Generation then continues with the expression returned by the last successful transform
    (transforms may also have mutated it in place).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence leaves the expression untouched.
        generator: optional function used to generate the SQL of the transformed expression,
            taking precedence over the lookups described above.

    Returns:
        Function that can be used as a generator transform. Calling it raises ValueError when
        no way to generate SQL for the transformed expression is found.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            # Iterating over the whole sequence (instead of indexing transforms[0]) also
            # supports an empty sequence.
            for transform in transforms:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        if generator:
            return generator(self, expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the optional generator function, the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (the latter only when the transformations changed the expression's type).

Arguments:
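  • transforms: sequence of transform functions. These will be called in order.
  • generator: optional function used to generate the SQL of the transformed expression instead of the default lookup.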
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 69def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 70    if isinstance(expression, exp.Select):
 71        count = 0
 72        recursive_ctes = []
 73
 74        for unnest in expression.find_all(exp.Unnest):
 75            if (
 76                not isinstance(unnest.parent, (exp.From, exp.Join))
 77                or len(unnest.expressions) != 1
 78                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 79            ):
 80                continue
 81
 82            generate_date_array = unnest.expressions[0]
 83            start = generate_date_array.args.get("start")
 84            end = generate_date_array.args.get("end")
 85            step = generate_date_array.args.get("step")
 86
 87            if not start or not end or not isinstance(step, exp.Interval):
 88                continue
 89
 90            alias = unnest.args.get("alias")
 91            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 92
 93            start = exp.cast(start, "date")
 94            date_add = exp.func(
 95                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 96            )
 97            cast_date_add = exp.cast(date_add, "date")
 98
 99            cte_name = "_generated_dates" + (f"_{count}" if count else "")
100
101            base_query = exp.select(start.as_(column_name))
102            recursive_query = (
103                exp.select(cast_date_add)
104                .from_(cte_name)
105                .where(cast_date_add <= exp.cast(end, "date"))
106            )
107            cte_query = base_query.union(recursive_query, distinct=False)
108
109            generate_dates_query = exp.select(column_name).from_(cte_name)
110            unnest.replace(generate_dates_query.subquery(cte_name))
111
112            recursive_ctes.append(
113                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
114            )
115            count += 1
116
117        if recursive_ctes:
118            with_expression = expression.args.get("with_") or exp.With()
119            with_expression.set("recursive", True)
120            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
121            expression.set("with_", with_expression)
122
123    return expression
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A table whose source is a GENERATE_SERIES call becomes `UNNEST(<call>)`. If the table had an
    alias, that alias becomes the unnest's column alias and the unnest itself is aliased `_u`.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    table_alias = expression.alias
    if not table_alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The query becomes

        SELECT <names> FROM (
            SELECT <aliased projections>,
                   ROW_NUMBER() OVER (PARTITION BY <on columns> ORDER BY ...) AS _row_number
            FROM ...
        ) AS _t WHERE _row_number = 1

    The window is ordered by the query's ORDER BY, which is moved out of the query, or by the
    DISTINCT ON columns if there is none. `_row_number` gets a fresh name if it is already taken,
    and unaliased projections are aliased so the outer query can reference them.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    distinct = expression.args.get("distinct")
    if not (
        isinstance(expression, exp.Select)
        and distinct
        and isinstance(distinct.args.get("on"), exp.Tuple)
    ):
        return expression

    row_number_alias = find_new_name(expression.named_selects, "_row_number")

    partition_cols = distinct.pop().args["on"].expressions
    window = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    order = expression.args.get("order")
    window.set(
        "order",
        order.pop() if order else exp.Order(expressions=[col.copy() for col in partition_cols]),
    )

    expression.select(exp.alias_(window, row_number_alias), copy=False)

    # Alias every original projection (all but the row number we just appended) so that the
    # outer query can safely reference it by name; a star short-circuits to `SELECT *`.
    outer_projections: t.List[exp.Expression] = []
    used_names = {row_number_alias}
    for projection in expression.selects[:-1]:
        if projection.is_star:
            outer_projections = [exp.Star()]
            break

        if not isinstance(projection, exp.Alias):
            new_alias = find_new_name(used_names, projection.output_name or "_col")
            quoted = (
                projection.this.args.get("quoted") if isinstance(projection, exp.Column) else None
            )
            projection = projection.replace(exp.alias_(projection, new_alias, quoted=quoted))

        used_names.add(projection.output_name)
        outer_projections.append(projection.args["alias"])

    outer = exp.select(*outer_projections, copy=False)
    outer = outer.from_(expression.subquery("_t", copy=False), copy=False)
    return outer.where(exp.column(row_number_alias).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    projection alias is referenced inside a window function of the QUALIFY clause, it is replaced
    by a copy of the corresponding expression to avoid creating invalid column references.

    The result has the shape
    `SELECT <original projection names> FROM (<original query>) AS _t WHERE <rewritten filter>`.

    Args:
        expression: the expression to transform.

    Returns:
        A new SELECT wrapping the original query, or `expression` itself when it is not a SELECT
        with a QUALIFY clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so the outer query can reference it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Keep the quoting of the projection's name when referencing it from the outer query.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Built before any helper projections are added below, so the outer query only
        # exposes the original projections.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With SELECT *, every column is already available to the outer query.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            # Insert a copy: the original node must stay in its projection
                            # instead of being re-parented into the window.
                            column.replace(expr.copy())

                # Project the window in the subquery and filter on its alias outside.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every DataType in the tree loses its DataTypeParam arguments (e.g. the `10, 2` of
    `DECIMAL(10, 2)`); its other arguments are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [
            arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept_args)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Only the aliases of UNNESTs that appear directly in this SELECT's FROM or JOIN clauses are
    considered. A column whose leftmost qualifier is one of them loses that qualifier.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = {
        u.alias
        for u in find_all_in_scope(expression, exp.Unnest)
        if isinstance(u.parent, (exp.From, exp.Join))
    }
    if not unnest_aliases:
        return expression

    for col in expression.find_all(exp.Column):
        qualifier = col.parts[0]
        # A leftmost part stored under "this" is the column name itself, not a qualifier.
        if qualifier.arg_key != "this" and qualifier.this in unnest_aliases:
            qualifier.pop()

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.Expression:
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Only SELECT expressions are rewritten:

    - `FROM UNNEST(...)` is replaced by a table whose source is an EXPLODE, POSEXPLODE or INLINE
      call.
    - Each `JOIN UNNEST(...)` / `JOIN LATERAL UNNEST(...)` is removed from the joins. It is
      appended to the query's laterals as a LATERAL VIEW entry.

    POSEXPLODE is used whenever the UNNEST requests an offset, and the offset column (named `pos`
    unless the UNNEST names it) is prepended to the column aliases. Otherwise INLINE is used for
    multiple arrays, which are first combined with ARRAYS_ZIP, and EXPLODE for a single array.

    Args:
        expression: the expression to transform.
        unnest_using_arrays_zip: whether an UNNEST over multiple arrays may be rewritten using
            ARRAYS_ZIP; when False, such an UNNEST raises UnsupportedError.

    Returns:
        The expression, modified in place.

    Raises:
        UnsupportedError: for a multi-array UNNEST when `unnest_using_arrays_zip` is False, or
            for a joined UNNEST whose column aliases cannot be mapped onto a LATERAL VIEW (no
            column aliases at all, or more than two for a single array).
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Combine several arrays into a single ARRAYS_ZIP(...) argument, updating the UNNEST in
        # place; a single array is returned unchanged.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # Pick the table-generating function that replaces the UNNEST.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from_")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *_ = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            columns = alias.columns if alias else []
            offset = unnest.args.get("offset")
            if offset:
                columns.insert(
                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
                )

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(this=this),
                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Spark LATERAL VIEW EXPLODE requires a single alias for an array/struct and two
                # for a Map column, unlike UNNEST in Trino/Presto which can take an arbitrary
                # amount. A multi-array UNNEST (INLINE over ARRAYS_ZIP) needs at least one alias
                # too: without one no LATERAL VIEW would be emitted below and the join would
                # silently disappear from the query.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html
                if not alias_cols or (not has_multi_expr and len(alias_cols) > 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                offset = unnest.args.get("offset")
                if offset:
                    alias_cols.insert(
                        0,
                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
                    )

                # `exprs` holds the single (possibly zipped) UNNEST argument, so this emits one
                # LATERAL VIEW carrying all column aliases.
                for e, _ in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_projection_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode projections into unnests.

    For a SELECT with [POS]EXPLODE[_OUTER] projections, the returned transform:

    - the first time an explode is found, adds a position series
      ``UNNEST(GENERATE_SERIES(index_offset, <end>))`` aliased ``_u(pos)`` to the query
      (as the FROM clause, or as a CROSS join when a FROM already exists);
    - CROSS JOINs an ``exp.Unnest`` of each exploded array, with ``offset=<pos_alias>``
      and aliased ``_u_N(<col_alias>)``;
    - replaces the explode call with ``IF(_u.pos = _u_N.<pos_alias>, _u_N.<col_alias>)``
      and, for POSEXPLODE, inserts a matching ``IF(...) AS <pos_alias>`` projection
      right after it;
    - adds WHERE predicates pairing each series position with the same array position,
      or, once the series has run past an array's end, with that array's last position
      (assuming 0- or 1-based offsets).

    The series end is the GREATEST of all exploded array sizes, adjusted by
    ``index_offset``. Generated names (``_u``, ``pos``, ``col``, ...) are made unique
    against the existing projection names and source names.

    Args:
        index_offset: value the position series starts at. When it is not 1, array sizes
            are decremented by one to obtain the last position. Presumably 0 or 1 to
            match the target dialect's offset base -- confirm with callers.

    Returns:
        A transform that rewrites ``exp.Select`` nodes and returns any other node unchanged.
    """

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve (and return) a fresh name derived from `name`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection.
                    # It must run before the EXPLODE_OUTER rewrite below: that rewrite wraps the
                    # argument in IF(...), which would hide a plain column argument and let the
                    # generated alias collide with it.
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    if isinstance(explode, exp.ExplodeOuter):
                        # Keep a row for NULL/empty arrays: substitute a one-element array whose
                        # element is a safe, offset-based lookup of index 0 (presumably NULL).
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                    # Every explode needs a position column, whether or not it is projected.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, right after the value.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode in this SELECT: bring in the shared position series.
                        if expression.args.get("from_"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` becomes the array's last position for non-1-based offsets.
                    if index_offset != 1:
                        size = size - 1

                    # Match series and array positions; past a shorter array's end, keep only
                    # the row for its last position so the row is not dropped.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must span the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest

Convert explode/posexplode projections into unnests.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them.

    A percentile function that is not already wrapped in WITHIN GROUP and has a second
    argument is rewritten as follows:
    - its first argument (the column) becomes the ORDER BY of a new WITHIN GROUP;
    - its second argument becomes its only argument.
    Any other node is returned unchanged.
    """
    needs_within_group = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not needs_within_group:
        return expression

    sort_key = expression.this.pop()
    expression.set("this", expression.expression.pop())

    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=sort_key)]),
    )

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    ``PERCENTILE(q) WITHIN GROUP (ORDER BY x)`` is replaced in the tree by
    ``ApproxQuantile(this=x, quantile=q)``, and the replacement is returned.
    Any other node is returned unchanged.
    """
    if not (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        return expression

    percentile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=percentile)
    return expression.replace(replacement)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only applies to a recursive WITH. For each CTE that has no explicit column list, the
    column list is built from the projections of its query; for a set operation, the
    left-hand query is used. Projections without an alias or name get generated names
    from ``name_sequence("_c_")``, shared across all CTEs of this WITH.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fallback_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        body = cte.this
        if isinstance(body, exp.SetOperation):
            body = body.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or fallback_name())
            for projection in body.selects
        ]
        cte_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal.

    A CAST/TRY_CAST has its operand swapped for the string literal
    '1970-01-01 00:00:00' when both of these hold:
    - its ``name`` is 'epoch' (case-insensitive);
    - its target type is one of ``exp.DataType.TEMPORAL_TYPES``.
    The input node is returned in all cases.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    if expression.name.lower() != "epoch":
        return expression

    if expression.to.this not in exp.DataType.TEMPORAL_TYPES:
        return expression

    expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    ``... SEMI JOIN t ON cond`` becomes ``... WHERE EXISTS(SELECT 1 FROM t WHERE cond)``.
    An ANTI join becomes the corresponding ``NOT EXISTS``. SEMI/ANTI joins without an ON
    condition are left untouched.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` detaches the join from this Select's
        # `joins` list (in current sqlglot, by mutating that list in place). Iterating the
        # live list would then skip the join that immediately follows a removed one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join with an ON or USING condition;
    any other query (e.g. one using a NATURAL FULL OUTER join) is returned unchanged.

    The LEFT side keeps the CTEs but loses ORDER BY and LIMIT. The RIGHT side loses the
    CTEs and gains a ``NOT EXISTS (SELECT 1 FROM <from> WHERE <join condition>)`` filter,
    so rows already produced by the LEFT side are not emitted again. The RIGHT side keeps
    ORDER BY/LIMIT, presumably so they trail the UNION ALL and apply to the whole result.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_outer_joins = [
        (index, join)
        for index, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]
    if len(full_outer_joins) != 1:
        return expression

    index, full_outer_join = full_outer_joins[0]
    on = full_outer_join.args.get("on")
    using = full_outer_join.args.get("using")

    # Without ON or USING there is no condition to build the anti-join from. Bail out
    # before mutating anything, rather than failing while iterating a missing USING list.
    if not on and not using:
        return expression

    expression_copy = expression.copy()
    expression.set("limit", None)

    tables = (expression.args["from_"].alias_or_name, full_outer_join.alias_or_name)
    join_conditions = on or exp.and_(
        *[exp.column(col, tables[0]).eq(exp.column(col, tables[1])) for col in using]
    )

    full_outer_join.set("side", "left")
    anti_join_clause = exp.select("1").from_(expression.args["from_"]).where(join_conditions)
    expression_copy.args["joins"][index].set("side", "right")
    expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
    expression_copy.set("with_", None)  # remove CTEs from RIGHT side
    expression.set("order", None)  # remove order by from LEFT side

    return exp.union(expression, expression_copy, copy=False, distinct=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: ~E) -> ~E:
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    Placement of moved CTEs:
    - If there is no top-level WITH yet, the first nested WITH found is promoted as a whole.
    - Otherwise, CTEs of a nested WITH that lives inside another CTE are spliced in just
      before that enclosing CTE, so they are defined before use.
    - All remaining nested CTEs are appended at the end.
    - A recursive nested WITH makes the top-level WITH recursive.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    root_with = expression.args.get("with_")

    for nested_with in expression.find_all(exp.With):
        # A WITH attached directly to the root is already where it belongs.
        if nested_with.parent is expression:
            continue

        if not root_with:
            root_with = nested_with.pop()
            expression.set("with_", root_with)
            continue

        if nested_with.recursive:
            root_with.set("recursive", True)

        # Look up the enclosing CTE before detaching, while the ancestry is still intact.
        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if not enclosing_cte:
            root_with.set("expressions", root_with.expressions + nested_with.expressions)
            continue

        position = root_with.expressions.index(enclosing_cte)
        root_with.expressions[position:position] = nested_with.expressions
        root_with.set("expressions", root_with.expressions)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions.

    Every node of ``expression`` is passed to the optimizer's canonicalize
    ``ensure_bools`` helper together with a callback. Presumably the helper only invokes
    the callback for nodes in boolean positions; confirm in sqlglot.optimizer.canonicalize.

    The callback replaces a node with ``node.neq(0)`` when it is:
    - a number literal;
    - typed UNKNOWN or numeric, subquery predicates excluded;
    - a column with no type annotation.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as canonical_ensure_bools

    def _to_boolean(node: exp.Expression) -> None:
        numeric_like = node.is_number or (
            not isinstance(node, exp.SubqueryPredicate)
            and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
        )
        if numeric_like or (isinstance(node, exp.Column) and not node.type):
            node.replace(node.neq(0))

    for node in expression.walk():
        canonical_ensure_bools(node, _to_boolean)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strip the qualifiers (table, db, catalog) from every column in ``expression``.

    Every part of each column except the last one is popped, leaving only the column
    name. The tree is mutated in place and returned.
    """
    for column in expression.find_all(exp.Column):
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Remove UNIQUE constraints from a CREATE statement.

    For every ``UniqueColumnConstraint`` found, its parent node (presumably the wrapping
    column constraint) is popped from the tree. The mutated CREATE is returned.
    """
    assert isinstance(expression, exp.Create)

    for unique_constraint in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique_constraint.parent
        if owner:
            owner.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """Rewrite ``CREATE TEMPORARY TABLE`` statements.

    A CTAS (``CREATE TEMPORARY TABLE ... AS <query>``) becomes a new
    ``CREATE TEMPORARY VIEW`` over the same name and query. A temporary table without
    a query is handed to ``tmp_storage_provider``, which by default returns it
    unchanged. Anything else is returned as is.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []
    is_temporary = any(isinstance(node, exp.TemporaryProperty) for node in property_nodes)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Only TABLE/VIEW creates with an explicit schema are affected. Column names are matched
    case-insensitively (both sides are upper-cased).
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if (
        not partitioned_by
        or not partitioned_by.this
        or isinstance(partitioned_by.this, exp.Schema)
    ):
        return expression

    schema = expression.this
    partition_names = {ref.name.upper() for ref in partitioned_by.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    partitioned_by.replace(
        exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns))
    )
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When PARTITIONED BY holds a schema made only of typed column definitions, those
    definitions are appended to the table's own schema. PARTITIONED BY is then reduced
    to a tuple of the column names.
    """
    assert isinstance(expression, exp.Create)

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not (
        partitioned_by
        and partitioned_by.this
        and isinstance(partitioned_by.this, exp.Schema)
    ):
        return expression

    column_defs = partitioned_by.this.expressions
    if not all(isinstance(d, exp.ColumnDef) and d.kind for d in column_defs):
        return expression

    name_tuple = exp.Tuple(expressions=[exp.to_identifier(d.this) for d in column_defs])

    table_schema = expression.this
    for column_def in column_defs:
        table_schema.append("expressions", column_def)

    partitioned_by.set("this", name_tuple)
    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each ``PropertyEQ`` argument (``key := value``) of a STRUCT is rewritten to
    ``value AS key``. Other arguments, and non-STRUCT nodes, are left as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    converted = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            field = exp.alias_(field.expression, field.this)
        converted.append(field)

    expression.set("expressions", converted)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """Converts Oracle-style ``(+)`` join marks in WHERE clauses into explicit LEFT joins.

    Reference: https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

    -- example with WHERE
    SELECT d.department_name, sum(e.salary) as total_salary
    FROM departments d, employees e
    WHERE e.department_id(+) = d.department_id
    group by department_name

    -- example of left correlation in select
    SELECT d.department_name, (
        SELECT SUM(e.salary)
            FROM employees e
            WHERE e.department_id(+) = d.department_id) AS total_salary
    FROM departments d;

    -- example of left correlation in from
    SELECT d.department_name, t.total_salary
    FROM departments d, (
            SELECT SUM(e.salary) AS total_salary
            FROM employees e
            WHERE e.department_id(+) = d.department_id
        ) t

    Raises:
        AssertionError: on correlated subqueries that use join marks, on WHERE clauses
            that cannot be normalized, on a single predicate marking columns of several
            tables, or when no table is left to serve as the new FROM.
    """

    from sqlglot.optimizer.scope import traverse_scope
    from sqlglot.optimizer.normalize import normalize, normalized
    from collections import defaultdict

    # we go in reverse to check the main query for left correlation
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where = query.args.get("where")
        # `.get("joins", [])` only covers a missing key; also guard against an explicit None
        joins = query.args.get("joins") or []

        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
            continue

        # knockout: we do not support left correlation (see point 2)
        assert not scope.is_correlated_subquery, "Correlated queries are not supported"

        # make sure we have AND of ORs to have clear join terms
        where = normalize(where.this)
        assert normalized(where), "Cannot normalize JOIN predicates"

        joins_ons = defaultdict(list)  # dict of {name: list of join AND conditions}
        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]

            left_join_table = set(col.table for col in join_cols)
            if not left_join_table:
                continue

            assert not (
                len(left_join_table) > 1
            ), "Cannot combine JOIN predicates from different tables"

            for col in join_cols:
                col.set("join_mark", False)

            joins_ons[left_join_table.pop()].append(cond)

        old_joins = {join.alias_or_name: join for join in joins}
        new_joins = {}
        query_from = query.args["from_"]

        for table, predicates in joins_ons.items():
            join_what = old_joins.get(table, query_from).this.copy()
            new_joins[join_what.alias_or_name] = exp.Join(
                this=join_what, on=exp.and_(*predicates), kind="LEFT"
            )

            # Remove each predicate that moved into the ON clause from the WHERE tree.
            for p in predicates:
                while isinstance(p.parent, exp.Paren):
                    p.parent.replace(p)

                parent = p.parent
                p.pop()
                if isinstance(parent, exp.Binary):
                    parent.replace(parent.right if parent.left is None else parent.left)
                elif isinstance(parent, exp.Where):
                    parent.pop()

        # The FROM table became the right side of a LEFT join: promote another table to FROM.
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from_", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            # `old_joins` is keyed by alias_or_name, so compare against the FROM's
            # alias_or_name too. Comparing with the bare table name would re-add an aliased
            # table that was just promoted to FROM as an extra CROSS join.
            from_name = query.args["from_"].alias_or_name
            for n, j in old_joins.items():  # preserve any other joins
                if n not in new_joins and n != from_name:
                    if not j.kind:
                        j.set("kind", "CROSS")
                    new_joins[n] = j
            query.set("joins", list(new_joins.values()))

    return expression

https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

  1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

  2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

-- example with WHERE

    SELECT d.department_name, sum(e.salary) as total_salary
    FROM departments d, employees e
    WHERE e.department_id(+) = d.department_id
    group by department_name

-- example of left correlation in select

    SELECT d.department_name, (
        SELECT SUM(e.salary)
            FROM employees e
            WHERE e.department_id(+) = d.department_id) AS total_salary
    FROM departments d;

-- example of left correlation in from

    SELECT d.department_name, t.total_salary
    FROM departments d, (
            SELECT SUM(e.salary) AS total_salary
            FROM employees e
            WHERE e.department_id(+) = d.department_id
        ) t

def any_to_exists( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> 5 > x)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation. ANY over a subquery, and ANY directly under LIKE/ILIKE, are left unchanged.

    The lambda parameter is named ``x`` unless a column in the comparison starts with that
    name. In that case a fresh name (``x_2``, ...) is used so the parameter does not shadow
    the column.
    """
    if isinstance(expression, exp.Select):
        for any_expr in expression.find_all(exp.Any):
            this = any_expr.this
            if isinstance(this, exp.Query) or isinstance(any_expr.parent, (exp.Like, exp.ILike)):
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                # A column whose leading name equals the lambda parameter would be captured
                # by it, e.g. `x > ANY(arr)` must not become `EXISTS(arr, x -> x > x)`.
                taken = {
                    col.parts[0].name.lower()
                    for col in binop.find_all(exp.Column)
                    if col.parts
                }
                lambda_arg = exp.to_identifier(find_new_name(taken, "x"))
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression

Transform ANY operator to Spark's EXISTS

For example:

- Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
- Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation

def eliminate_window_clause( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 976def eliminate_window_clause(expression: exp.Expression) -> exp.Expression:
 977    """Eliminates the `WINDOW` query clause by inling each named window."""
 978    if isinstance(expression, exp.Select) and expression.args.get("windows"):
 979        from sqlglot.optimizer.scope import find_all_in_scope
 980
 981        windows = expression.args["windows"]
 982        expression.set("windows", None)
 983
 984        window_expression: t.Dict[str, exp.Expression] = {}
 985
 986        def _inline_inherited_window(window: exp.Expression) -> None:
 987            inherited_window = window_expression.get(window.alias.lower())
 988            if not inherited_window:
 989                return
 990
 991            window.set("alias", None)
 992            for key in ("partition_by", "order", "spec"):
 993                arg = inherited_window.args.get(key)
 994                if arg:
 995                    window.set(key, arg.copy())
 996
 997        for window in windows:
 998            _inline_inherited_window(window)
 999            window_expression[window.name.lower()] = window
1000
1001        for window in find_all_in_scope(expression, exp.Window):
1002            _inline_inherited_window(window)
1003
1004    return expression

Eliminates the WINDOW query clause by inlining each named window.

def inherit_struct_field_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def inherit_struct_field_names(expression: exp.Expression) -> exp.Expression:
    """
    Inherit field names from the first struct in an array.

    BigQuery supports implicitly inheriting names from the first STRUCT in an array:

    Example:
        ARRAY[
          STRUCT('Alice' AS name, 85 AS score),  -- defines names
          STRUCT('Bob', 92),                     -- inherits names
          STRUCT('Diana', 95)                    -- inherits names
        ]

    This transformation makes the field names explicit on all structs by adding
    PropertyEQ nodes, in order to facilitate transpilation to other dialects.

    Only arrays flagged with `struct_name_inheritance` are rewritten, and only when
    their first element is a struct whose fields are all named (PropertyEQ). A later
    struct is skipped unless it has exactly as many fields as the first one. Fields
    that are already named are left untouched.

    Args:
        expression: The expression tree to transform

    Returns:
        The modified expression with field names inherited in all structs
    """
    if (
        isinstance(expression, exp.Array)
        and expression.args.get("struct_name_inheritance")
        and isinstance(first_item := seq_get(expression.expressions, 0), exp.Struct)
        and all(isinstance(fld, exp.PropertyEQ) for fld in first_item.expressions)
    ):
        field_names = [fld.this for fld in first_item.expressions]

        def _field_identifier(name: exp.Expression) -> exp.Identifier:
            # Reuse a copy of the original identifier so its `quoted` flag survives.
            # Wrapping it in a new Identifier nests the nodes and loses the quoting.
            if isinstance(name, exp.Identifier):
                return name.copy()
            return exp.to_identifier(name.name)

        # Apply field names to subsequent structs that don't have them
        for struct in expression.expressions[1:]:
            if not isinstance(struct, exp.Struct) or len(struct.expressions) != len(field_names):
                continue

            # Convert unnamed expressions to PropertyEQ (field_name := value) with
            # the inherited names. The lengths match because of the check above.
            struct.set(
                "expressions",
                [
                    expr
                    if isinstance(expr, exp.PropertyEQ)
                    else exp.PropertyEQ(this=_field_identifier(name), expression=expr)
                    for name, expr in zip(field_names, struct.expressions)
                ],
            )

    return expression

Inherit field names from the first struct in an array.

BigQuery supports implicitly inheriting names from the first STRUCT in an array:

Example:

ARRAY[
  STRUCT('Alice' AS name, 85 AS score),  -- defines names
  STRUCT('Bob', 92),                     -- inherits names
  STRUCT('Diana', 95)                    -- inherits names
]

This transformation makes the field names explicit on all structs by adding PropertyEQ nodes, in order to facilitate transpilation to other dialects.

Arguments:
  • expression: The expression tree to transform
Returns:

The modified expression with field names inherited in all structs