Edit on GitHub

sqlglot.transforms

   1from __future__ import annotations
   2
   3import typing as t
   4
   5from sqlglot import expressions as exp
   6from sqlglot.errors import UnsupportedError
   7from sqlglot.helper import find_new_name, name_sequence
   8
   9
  10if t.TYPE_CHECKING:
  11    from sqlglot._typing import E
  12    from sqlglot.generator import Generator
  13
  14
  15def preprocess(
  16    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
  17) -> t.Callable[[Generator, exp.Expression], str]:
  18    """
  19    Creates a new transform by chaining a sequence of transformations and converts the resulting
  20    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
  21    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
  22
  23    Args:
  24        transforms: sequence of transform functions. These will be called in order.
  25
  26    Returns:
  27        Function that can be used as a generator transform.
  28    """
  29
  30    def _to_sql(self, expression: exp.Expression) -> str:
  31        expression_type = type(expression)
  32
  33        try:
  34            expression = transforms[0](expression)
  35            for transform in transforms[1:]:
  36                expression = transform(expression)
  37        except UnsupportedError as unsupported_error:
  38            self.unsupported(str(unsupported_error))
  39
  40        _sql_handler = getattr(self, expression.key + "_sql", None)
  41        if _sql_handler:
  42            return _sql_handler(expression)
  43
  44        transforms_handler = self.TRANSFORMS.get(type(expression))
  45        if transforms_handler:
  46            if expression_type is type(expression):
  47                if isinstance(expression, exp.Func):
  48                    return self.function_fallback_sql(expression)
  49
  50                # Ensures we don't enter an infinite loop. This can happen when the original expression
  51                # has the same type as the final expression and there's no _sql method available for it,
  52                # because then it'd re-enter _to_sql.
  53                raise ValueError(
  54                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
  55                )
  56
  57            return transforms_handler(self, expression)
  58
  59        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
  60
  61    return _to_sql
  62
  63
  64def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
  65    if isinstance(expression, exp.Select):
  66        count = 0
  67        recursive_ctes = []
  68
  69        for unnest in expression.find_all(exp.Unnest):
  70            if (
  71                not isinstance(unnest.parent, (exp.From, exp.Join))
  72                or len(unnest.expressions) != 1
  73                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
  74            ):
  75                continue
  76
  77            generate_date_array = unnest.expressions[0]
  78            start = generate_date_array.args.get("start")
  79            end = generate_date_array.args.get("end")
  80            step = generate_date_array.args.get("step")
  81
  82            if not start or not end or not isinstance(step, exp.Interval):
  83                continue
  84
  85            alias = unnest.args.get("alias")
  86            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
  87
  88            start = exp.cast(start, "date")
  89            date_add = exp.func(
  90                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
  91            )
  92            cast_date_add = exp.cast(date_add, "date")
  93
  94            cte_name = "_generated_dates" + (f"_{count}" if count else "")
  95
  96            base_query = exp.select(start.as_(column_name))
  97            recursive_query = (
  98                exp.select(cast_date_add)
  99                .from_(cte_name)
 100                .where(cast_date_add <= exp.cast(end, "date"))
 101            )
 102            cte_query = base_query.union(recursive_query, distinct=False)
 103
 104            generate_dates_query = exp.select(column_name).from_(cte_name)
 105            unnest.replace(generate_dates_query.subquery(cte_name))
 106
 107            recursive_ctes.append(
 108                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
 109            )
 110            count += 1
 111
 112        if recursive_ctes:
 113            with_expression = expression.args.get("with") or exp.With()
 114            with_expression.set("recursive", True)
 115            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
 116            expression.set("with", with_expression)
 117
 118    return expression
 119
 120
 121def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
 122    """Unnests GENERATE_SERIES or SEQUENCE table references."""
 123    this = expression.this
 124    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
 125        unnest = exp.Unnest(expressions=[this])
 126        if expression.alias:
 127            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
 128
 129        return unnest
 130
 131    return expression
 132
 133
 134def unalias_group(expression: exp.Expression) -> exp.Expression:
 135    """
 136    Replace references to select aliases in GROUP BY clauses.
 137
 138    Example:
 139        >>> import sqlglot
 140        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 141        'SELECT a AS b FROM x GROUP BY 1'
 142
 143    Args:
 144        expression: the expression that will be transformed.
 145
 146    Returns:
 147        The transformed expression.
 148    """
 149    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 150        aliased_selects = {
 151            e.alias: i
 152            for i, e in enumerate(expression.parent.expressions, start=1)
 153            if isinstance(e, exp.Alias)
 154        }
 155
 156        for group_by in expression.expressions:
 157            if (
 158                isinstance(group_by, exp.Column)
 159                and not group_by.table
 160                and group_by.name in aliased_selects
 161            ):
 162                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 163
 164    return expression
 165
 166
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The query is rewritten roughly as:

        SELECT <projection aliases> FROM (
            SELECT <aliased projections>,
                   ROW_NUMBER() OVER (PARTITION BY <on cols> ORDER BY <order>) AS _row_number
            FROM ...
        ) AS _t
        WHERE _row_number = 1

    The query's ORDER BY, if any, is moved into the window. Otherwise the window orders by the
    DISTINCT ON columns themselves.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        # Window alias that doesn't collide with an existing projection name
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        # Detach the DISTINCT clause from the query; its ON columns become the window partition
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        order = expression.args.get("order")
        if order:
            # The ORDER BY is popped off the query and reused as the window's ordering
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number_window_alias)
        # The ROW_NUMBER() projection is appended last; the loop below skips it via [:-1]
        expression.select(window, copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects = []
        taken_names = {row_number_window_alias}
        for select in expression.selects[:-1]:
            if select.is_star:
                # NOTE(review): the outer `SELECT *` presumably also exposes the subquery's
                # row-number column — confirm this is acceptable for callers.
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                # Reuse the column identifier's quoting for the generated alias
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression
 221
 222
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        `SELECT <original projections> FROM (<inner query>) AS _t WHERE <qualify condition>`
        for a SELECT with a QUALIFY clause; any other expression unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias (_c, _c_2, ...) so the outer query can refer to it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Outer projection for an inner select: a column that keeps the quoting of the
        # alias/identifier when there is one, otherwise the plain name
        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Built before helper projections are appended below, so the outer query only exposes
        # the original projections
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach QUALIFY from the inner query; its condition becomes the outer WHERE
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With `SELECT *` only windows are collected; referenced columns are presumably already
        # exposed by the star
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline projection aliases referenced inside the window, since the window is
                    # about to become a projection itself (see the docstring's last point)
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh `_w` alias and refer to it by that column
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the entire QUALIFY condition, so the filter becomes the column
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column used in QUALIFY but not projected: project it so the outer WHERE can see it
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
 285
 286
 287def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
 288    """
 289    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
 290    other expressions. This transforms removes the precision from parameterized types in expressions.
 291    """
 292    for node in expression.find_all(exp.DataType):
 293        node.set(
 294            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
 295        )
 296
 297    return expression
 298
 299
 300def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
 301    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
 302    from sqlglot.optimizer.scope import find_all_in_scope
 303
 304    if isinstance(expression, exp.Select):
 305        unnest_aliases = {
 306            unnest.alias
 307            for unnest in find_all_in_scope(expression, exp.Unnest)
 308            if isinstance(unnest.parent, (exp.From, exp.Join))
 309        }
 310        if unnest_aliases:
 311            for column in expression.find_all(exp.Column):
 312                leftmost_part = column.parts[0]
 313                if leftmost_part.arg_key != "this" and leftmost_part.this in unnest_aliases:
 314                    leftmost_part.pop()
 315
 316    return expression
 317
 318
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Two shapes are handled:
    - An UNNEST used as the FROM source is replaced in place by a table whose `this` is an
      EXPLODE / POSEXPLODE / INLINE call.
    - An UNNEST used as a join (optionally wrapped in LATERAL) is removed from the joins and
      re-added under the SELECT's "laterals" as `LATERAL VIEW` entries.

    Args:
        expression: the expression that will be transformed.
        unnest_using_arrays_zip: when True, an UNNEST over several arrays is rewritten as
            INLINE(ARRAYS_ZIP(...)); when False such an UNNEST raises UnsupportedError.

    Returns:
        The transformed expression.

    Raises:
        UnsupportedError: for a multi-array UNNEST when `unnest_using_arrays_zip` is False, or
            for a single-array UNNEST join whose alias doesn't list exactly 1 or 2 columns.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapse several unnested arrays into a single ARRAYS_ZIP(...) argument (this also
        # rewrites the UNNEST node's expressions); a single array is returned unchanged.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # POSEXPLODE when the UNNEST has an offset, INLINE for zipped multi-array input,
        # EXPLODE otherwise
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            # NOTE(review): `alias.columns` may be the alias's own list, in which case the
            # insert below mutates it in place before the alias is rebuilt — verify.
            columns = alias.columns if alias else []
            offset = unnest.args.get("offset")
            if offset:
                # The position column comes first; default its name to "pos"
                columns.insert(
                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
                )

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        # Iterate over a copy because matching joins are removed from `joins` below
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                offset = unnest.args.get("offset")
                if offset:
                    alias_cols.insert(
                        0,
                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
                    )

                # NOTE(review): `column` is unused and every lateral gets the full `alias_cols`.
                # zip() stops at the shorter list, so a multi-array UNNEST join without an alias
                # (empty `alias_cols`) appears to be dropped silently after being removed from
                # the joins above — confirm intended behavior.
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
 424
 425
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode projections into unnests.

    Each projection that contains an EXPLODE / POSEXPLODE / EXPLODE_OUTER call is rewritten:
    - The exploded array is cross joined as an UNNEST aliased `_u_N(<col>)` with an offset
      column.
    - The call itself becomes `IF(<series pos> = <element pos>, <element>)`.

    A single shared `UNNEST(GENERATE_SERIES(index_offset, <end>))` source, aliased `_u(pos)`,
    is also cross joined. <end> is derived from GREATEST of the exploded arrays' sizes. WHERE
    filters then keep, for each exploded array, exactly one row per series position: the
    matching element, or the array's last element once the series runs past its end.

    Args:
        index_offset: first value of the generated position series. It appears to be the
            array index base of the target dialect (0 or 1) — NOTE(review): confirm.

    Returns:
        A transform that rewrites SELECT expressions and returns other expressions unchanged.
    """

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources, so generated aliases don't collide
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve and return a fresh name derived from `name`
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Size expressions of all exploded arrays; used to compute the series' end
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # The series' "end" is filled in after the loop, once all array sizes are known
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection so `alias` is an Alias node wrapping it
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # e.g. POSEXPLODE(x) AS (pos, col): first alias is the position
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Substitute an empty/NULL array with a one-element array built from a
                        # safe index access (presumably yielding NULL), so the row is still
                        # produced as EXPLODE_OUTER requires
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Replace the EXPLODE call with "element if its position matches the series"
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the element position, right after the element
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The shared series source is added once, before the first array's unnest
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Turn the size into the last valid position unless positions are 1-based
                    if index_offset != 1:
                        size = size - 1

                    # Keep the row whose element position matches the series position, or the
                    # last element once the series has run past the end of this array
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series spans the longest exploded array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest
 577
 578
 579def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
 580    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
 581    if (
 582        isinstance(expression, exp.PERCENTILES)
 583        and not isinstance(expression.parent, exp.WithinGroup)
 584        and expression.expression
 585    ):
 586        column = expression.this.pop()
 587        expression.set("this", expression.expression.pop())
 588        order = exp.Order(expressions=[exp.Ordered(this=column)])
 589        expression = exp.WithinGroup(this=expression, expression=order)
 590
 591    return expression
 592
 593
 594def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
 595    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
 596    if (
 597        isinstance(expression, exp.WithinGroup)
 598        and isinstance(expression.this, exp.PERCENTILES)
 599        and isinstance(expression.expression, exp.Order)
 600    ):
 601        quantile = expression.this.this
 602        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
 603        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
 604
 605    return expression
 606
 607
 608def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
 609    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
 610    if isinstance(expression, exp.With) and expression.recursive:
 611        next_name = name_sequence("_c_")
 612
 613        for cte in expression.expressions:
 614            if not cte.args["alias"].columns:
 615                query = cte.this
 616                if isinstance(query, exp.SetOperation):
 617                    query = query.this
 618
 619                cte.args["alias"].set(
 620                    "columns",
 621                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
 622                )
 623
 624    return expression
 625
 626
 627def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
 628    """Replace 'epoch' in casts by the equivalent date literal."""
 629    if (
 630        isinstance(expression, (exp.Cast, exp.TryCast))
 631        and expression.name.lower() == "epoch"
 632        and expression.to.this in exp.DataType.TEMPORAL_TYPES
 633    ):
 634        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
 635
 636    return expression
 637
 638
 639def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
 640    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
 641    if isinstance(expression, exp.Select):
 642        for join in expression.args.get("joins") or []:
 643            on = join.args.get("on")
 644            if on and join.kind in ("SEMI", "ANTI"):
 645                subquery = exp.select("1").from_(join.this).where(on)
 646                exists = exp.Exists(this=subquery)
 647                if join.kind == "ANTI":
 648                    exists = exists.not_(copy=False)
 649
 650                join.pop()
 651                expression.where(exists, copy=False)
 652
 653    return expression
 654
 655
 656def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
 657    """
 658    Converts a query with a FULL OUTER join to a union of identical queries that
 659    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
 660    for queries that have a single FULL OUTER join.
 661    """
 662    if isinstance(expression, exp.Select):
 663        full_outer_joins = [
 664            (index, join)
 665            for index, join in enumerate(expression.args.get("joins") or [])
 666            if join.side == "FULL"
 667        ]
 668
 669        if len(full_outer_joins) == 1:
 670            expression_copy = expression.copy()
 671            expression.set("limit", None)
 672            index, full_outer_join = full_outer_joins[0]
 673
 674            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
 675            join_conditions = full_outer_join.args.get("on") or exp.and_(
 676                *[
 677                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
 678                    for col in full_outer_join.args.get("using")
 679                ]
 680            )
 681
 682            full_outer_join.set("side", "left")
 683            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
 684            expression_copy.args["joins"][index].set("side", "right")
 685            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
 686            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
 687            expression.args.pop("order", None)  # remove order by from LEFT side
 688
 689            return exp.union(expression, expression_copy, copy=False, distinct=False)
 690
 691    return expression
 692
 693
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    Nested WITH clauses are popped out of the tree:
    - If there is no top-level WITH yet, the first one found becomes it.
    - Otherwise its CTEs are merged into the top-level WITH. They go just before the
      top-level CTE that contained them, or at the end when they weren't inside a CTE.
    A RECURSIVE flag on a merged WITH is carried over to the top-level one.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # Skip the top-level WITH itself
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert before the enclosing CTE so the moved CTEs are defined before use.
                # NOTE(review): `.index` raises ValueError if `parent_cte` is not a direct
                # member of the top-level WITH — appears to be assumed here; confirm.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                # Re-set the list so the inserted CTEs are re-attached to the top-level WITH
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
 731
 732
 733def ensure_bools(expression: exp.Expression) -> exp.Expression:
 734    """Converts numeric values used in conditions into explicit boolean expressions."""
 735    from sqlglot.optimizer.canonicalize import ensure_bools
 736
 737    def _ensure_bool(node: exp.Expression) -> None:
 738        if (
 739            node.is_number
 740            or (
 741                not isinstance(node, exp.SubqueryPredicate)
 742                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
 743            )
 744            or (isinstance(node, exp.Column) and not node.type)
 745        ):
 746            node.replace(node.neq(0))
 747
 748    for node in expression.walk():
 749        ensure_bools(node, _ensure_bool)
 750
 751    return expression
 752
 753
 754def unqualify_columns(expression: exp.Expression) -> exp.Expression:
 755    for column in expression.find_all(exp.Column):
 756        # We only wanna pop off the table, db, catalog args
 757        for part in column.parts[:-1]:
 758            part.pop()
 759
 760    return expression
 761
 762
 763def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
 764    assert isinstance(expression, exp.Create)
 765    for constraint in expression.find_all(exp.UniqueColumnConstraint):
 766        if constraint.parent:
 767            constraint.parent.pop()
 768
 769    return expression
 770
 771
 772def ctas_with_tmp_tables_to_create_tmp_view(
 773    expression: exp.Expression,
 774    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
 775) -> exp.Expression:
 776    assert isinstance(expression, exp.Create)
 777    properties = expression.args.get("properties")
 778    temporary = any(
 779        isinstance(prop, exp.TemporaryProperty)
 780        for prop in (properties.expressions if properties else [])
 781    )
 782
 783    # CTAS with temp tables map to CREATE TEMPORARY VIEW
 784    if expression.kind == "TABLE" and temporary:
 785        if expression.expression:
 786            return exp.Create(
 787                kind="TEMPORARY VIEW",
 788                this=expression.this,
 789                expression=expression.expression,
 790            )
 791        return tmp_storage_provider(expression)
 792
 793    return expression
 794
 795
 796def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
 797    """
 798    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
 799    PARTITIONED BY value is an array of column names, they are transformed into a schema.
 800    The corresponding columns are removed from the create statement.
 801    """
 802    assert isinstance(expression, exp.Create)
 803    has_schema = isinstance(expression.this, exp.Schema)
 804    is_partitionable = expression.kind in {"TABLE", "VIEW"}
 805
 806    if has_schema and is_partitionable:
 807        prop = expression.find(exp.PartitionedByProperty)
 808        if prop and prop.this and not isinstance(prop.this, exp.Schema):
 809            schema = expression.this
 810            columns = {v.name.upper() for v in prop.this.expressions}
 811            partitions = [col for col in schema.expressions if col.name.upper() in columns]
 812            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
 813            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
 814            expression.set("this", schema)
 815
 816    return expression
 817
 818
 819def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
 820    """
 821    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
 822
 823    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
 824    """
 825    assert isinstance(expression, exp.Create)
 826    prop = expression.find(exp.PartitionedByProperty)
 827    if (
 828        prop
 829        and prop.this
 830        and isinstance(prop.this, exp.Schema)
 831        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
 832    ):
 833        prop_this = exp.Tuple(
 834            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
 835        )
 836        schema = expression.this
 837        for e in prop.this.expressions:
 838            schema.append("expressions", e)
 839        prop.set("this", prop_this)
 840
 841    return expression
 842
 843
 844def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
 845    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
 846    if isinstance(expression, exp.Struct):
 847        expression.set(
 848            "expressions",
 849            [
 850                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
 851                for e in expression.expressions
 852            ],
 853        )
 854
 855    return expression
 856
 857
 858def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
 859    """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178
 860
 861    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.
 862
 863    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.
 864
 865    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.
 866
 867    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.
 868
 869    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.
 870
 871    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.
 872
 873    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.
 874
 875    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.
 876
 877    -- example with WHERE
 878    SELECT d.department_name, sum(e.salary) as total_salary
 879    FROM departments d, employees e
 880    WHERE e.department_id(+) = d.department_id
 881    group by department_name
 882
 883    -- example of left correlation in select
 884    SELECT d.department_name, (
 885        SELECT SUM(e.salary)
 886            FROM employees e
 887            WHERE e.department_id(+) = d.department_id) AS total_salary
 888    FROM departments d;
 889
 890    -- example of left correlation in from
 891    SELECT d.department_name, t.total_salary
 892    FROM departments d, (
 893            SELECT SUM(e.salary) AS total_salary
 894            FROM employees e
 895            WHERE e.department_id(+) = d.department_id
 896        ) t
 897    """
 898
 899    from sqlglot.optimizer.scope import traverse_scope
 900    from sqlglot.optimizer.normalize import normalize, normalized
 901    from collections import defaultdict
 902
 903    # we go in reverse to check the main query for left correlation
 904    for scope in reversed(traverse_scope(expression)):
 905        query = scope.expression
 906
 907        where = query.args.get("where")
 908        joins = query.args.get("joins", [])
 909
 910        # knockout: we do not support left correlation (see point 2)
 911        assert not scope.is_correlated_subquery, "Correlated queries are not supported"
 912
 913        # nothing to do - we check it here after knockout above
 914        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
 915            continue
 916
 917        # make sure we have AND of ORs to have clear join terms
 918        where = normalize(where.this)
 919        assert normalized(where), "Cannot normalize JOIN predicates"
 920
 921        joins_ons = defaultdict(list)  # dict of {name: list of join AND conditions}
 922        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
 923            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]
 924
 925            left_join_table = set(col.table for col in join_cols)
 926            if not left_join_table:
 927                continue
 928
 929            assert not (
 930                len(left_join_table) > 1
 931            ), "Cannot combine JOIN predicates from different tables"
 932
 933            for col in join_cols:
 934                col.set("join_mark", False)
 935
 936            joins_ons[left_join_table.pop()].append(cond)
 937
 938        old_joins = {join.alias_or_name: join for join in joins}
 939        new_joins = {}
 940        query_from = query.args["from"]
 941
 942        for table, predicates in joins_ons.items():
 943            join_what = old_joins.get(table, query_from).this.copy()
 944            new_joins[join_what.alias_or_name] = exp.Join(
 945                this=join_what, on=exp.and_(*predicates), kind="LEFT"
 946            )
 947
 948            for p in predicates:
 949                while isinstance(p.parent, exp.Paren):
 950                    p.parent.replace(p)
 951
 952                parent = p.parent
 953                p.pop()
 954                if isinstance(parent, exp.Binary):
 955                    parent.replace(parent.right if parent.left is None else parent.left)
 956                elif isinstance(parent, exp.Where):
 957                    parent.pop()
 958
 959        if query_from.alias_or_name in new_joins:
 960            only_old_joins = old_joins.keys() - new_joins.keys()
 961            assert (
 962                len(only_old_joins) >= 1
 963            ), "Cannot determine which table to use in the new FROM clause"
 964
 965            new_from_name = list(only_old_joins)[0]
 966            query.set("from", exp.From(this=old_joins[new_from_name].this))
 967
 968        if new_joins:
 969            for n, j in old_joins.items():  # preserve any other joins
 970                if n not in new_joins and n != query.args["from"].name:
 971                    if not j.kind:
 972                        j.set("kind", "CROSS")
 973                    new_joins[n] = j
 974            query.set("joins", list(new_joins.values()))
 975
 976    return expression
 977
 978
 979def any_to_exists(expression: exp.Expression) -> exp.Expression:
 980    """
 981    Transform ANY operator to Spark's EXISTS
 982
 983    For example,
 984        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
 985        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
 986
 987    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
 988    transformation
 989    """
 990    if isinstance(expression, exp.Select):
 991        for any_expr in expression.find_all(exp.Any):
 992            this = any_expr.this
 993            if isinstance(this, exp.Query):
 994                continue
 995
 996            binop = any_expr.parent
 997            if isinstance(binop, exp.Binary):
 998                lambda_arg = exp.to_identifier("x")
 999                any_expr.replace(lambda_arg)
1000                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
1001                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))
1002
1003    return expression
1004
1005
def eliminate_window_clause(expression: exp.Expression) -> exp.Expression:
    """
    Eliminates the `WINDOW` query clause by inlining each named window.

    The WINDOW clause is removed from the SELECT. Every window (including later named
    windows) that refers to a named window by name gets that window's PARTITION BY,
    ORDER BY and frame spec copied in, and the reference is dropped. Names are matched
    case-insensitively.
    """
    if not isinstance(expression, exp.Select) or not expression.args.get("windows"):
        return expression

    from sqlglot.optimizer.scope import find_all_in_scope

    named_windows = expression.args["windows"]
    expression.set("windows", None)

    definitions: t.Dict[str, exp.Expression] = {}

    def _inline(window: exp.Expression) -> None:
        base = definitions.get(window.alias.lower())
        if not base:
            return

        window.set("alias", None)
        for key in ("partition_by", "order", "spec"):
            value = base.args.get(key)
            if value:
                window.set(key, value.copy())

    # Named windows may build on earlier ones, so resolve them in declaration order.
    for named in named_windows:
        _inline(named)
        definitions[named.name.lower()] = named

    for window in find_all_in_scope(expression, exp.Window):
        _inline(window)

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    The transformed expression is rendered with, in order of preference:
      1. the generator's `<key>_sql` method, if it has one;
      2. otherwise the `Generator.TRANSFORMS` entry for its type. If the type did not change,
         that entry would re-enter this very transform. So a function falls back to
         `function_fallback_sql` and anything else raises `ValueError`.
    A `ValueError` is also raised when neither handler exists.

    An `UnsupportedError` raised by a transform is reported via `Generator.unsupported` and
    generation continues with the expression as transformed up to that point.

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence leaves the expression unchanged.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            # A single loop applies the transforms in order and also copes with an empty list.
            for transform in transforms:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 65def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 66    if isinstance(expression, exp.Select):
 67        count = 0
 68        recursive_ctes = []
 69
 70        for unnest in expression.find_all(exp.Unnest):
 71            if (
 72                not isinstance(unnest.parent, (exp.From, exp.Join))
 73                or len(unnest.expressions) != 1
 74                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 75            ):
 76                continue
 77
 78            generate_date_array = unnest.expressions[0]
 79            start = generate_date_array.args.get("start")
 80            end = generate_date_array.args.get("end")
 81            step = generate_date_array.args.get("step")
 82
 83            if not start or not end or not isinstance(step, exp.Interval):
 84                continue
 85
 86            alias = unnest.args.get("alias")
 87            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 88
 89            start = exp.cast(start, "date")
 90            date_add = exp.func(
 91                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 92            )
 93            cast_date_add = exp.cast(date_add, "date")
 94
 95            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 96
 97            base_query = exp.select(start.as_(column_name))
 98            recursive_query = (
 99                exp.select(cast_date_add)
100                .from_(cte_name)
101                .where(cast_date_add <= exp.cast(end, "date"))
102            )
103            cte_query = base_query.union(recursive_query, distinct=False)
104
105            generate_dates_query = exp.select(column_name).from_(cte_name)
106            unnest.replace(generate_dates_query.subquery(cte_name))
107
108            recursive_ctes.append(
109                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
110            )
111            count += 1
112
113        if recursive_ctes:
114            with_expression = expression.args.get("with") or exp.With()
115            with_expression.set("recursive", True)
116            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
117            expression.set("with", with_expression)
118
119    return expression
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A table whose body is a series generator becomes `UNNEST(<series>)`. If the table was
    aliased, its alias becomes the column alias of the unnest (the table alias is `_u`).
    """
    this = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries)):
        return expression

    unnest = exp.Unnest(expressions=[this])
    table_alias = expression.alias
    if not table_alias:
        return unnest

    return exp.alias_(unnest, alias="_u", table=[table_alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses with the 1-based position of
    the aliased projection.

    Only unqualified column references are considered; when an alias is used more than once,
    the last projection carrying it wins.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    select = expression.parent
    if not isinstance(expression, exp.Group) or not isinstance(select, exp.Select):
        return expression

    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for key in expression.expressions:
        if not isinstance(key, exp.Column) or key.table:
            continue
        position = position_by_alias.get(key.name)
        if position is not None:
            key.replace(exp.Literal.number(position))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The query is rewritten, roughly, as:

        SELECT <projection aliases> FROM (
            SELECT <aliased projections>,
                   ROW_NUMBER() OVER (PARTITION BY <ON columns> ORDER BY <order>) AS <rn>
            FROM ...
        ) AS _t
        WHERE <rn> = 1

    where <rn> is a fresh name derived from `_row_number`. The query's ORDER BY (or, if
    absent, the DISTINCT ON columns) becomes the window's ordering.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        # Pick a window alias that does not collide with any existing projection name.
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        # Detach DISTINCT; its ON tuple becomes the window's partition.
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        # NOTE(review): the query's ORDER BY is moved into the window, so the rewritten
        # query's final result is no longer ordered.
        order = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number_window_alias)
        expression.select(window, copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects = []
        taken_names = {row_number_window_alias}
        # [:-1] skips the row-number window appended just above.
        for select in expression.selects[:-1]:
            # A star cannot be enumerated; the outer query selects * as well.
            # NOTE(review): the outer * then also exposes the row-number column.
            if select.is_star:
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                # NOTE(review): taken_names only holds names seen so far, so a generated alias
                # could still clash with an explicit alias of a later projection.
                alias = find_new_name(taken_names, select.output_name or "_col")
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    The result has the shape `SELECT <original projections> FROM (<query + extra projections>)
    AS _t WHERE <rewritten QUALIFY condition>`.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every nameless projection an alias (`_c`, ...) so the outer query can reference it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Outer-query reference to a projection: a column (keeping its quoting) when the name
        # comes from an identifier, otherwise the bare name.
        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Built before any helper projections are appended, so the outer query exposes only
        # the original projections.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With SELECT *, columns are already available in the subquery; only windows need projecting.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                # Inline projection aliases used inside the window, since aliases are not visible there.
                # NOTE(review): `expr` is inserted without copying, so the node is shared with the
                # projection it came from -- consider whether `expr.copy()` is needed.
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh `_w` alias and filter on that column instead.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # If the window is the whole QUALIFY condition there is nothing to replace it in,
                # so the column itself becomes the filter.
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Expose columns referenced only by QUALIFY so the outer WHERE can see them.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every DataType in the tree loses its DataTypeParam arguments; other arguments are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Aliases are collected from UNNESTs used as FROM/JOIN sources in the SELECT's own scope.
    Any column whose leftmost qualifier matches one of them has that qualifier removed.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    aliases = {
        u.alias
        for u in find_all_in_scope(expression, exp.Unnest)
        if isinstance(u.parent, (exp.From, exp.Join))
    }
    if not aliases:
        return expression

    for column in expression.find_all(exp.Column):
        qualifier = column.parts[0]
        # The leftmost part is a qualifier only when it is not the column name itself.
        if qualifier.arg_key != "this" and qualifier.this in aliases:
            qualifier.pop()

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.Expression:
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode.

    - `FROM UNNEST(...)` becomes `FROM <UDTF>(...)`.
    - Each joined `UNNEST(...)` (plain or LATERAL) is removed from the joins and replaced by a
      `LATERAL VIEW <UDTF>(...)`.

    The UDTF is POSEXPLODE when the UNNEST has an offset, INLINE when several arrays are
    unnested together (they are zipped with ARRAYS_ZIP), and EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed; it is modified in place.
        unnest_using_arrays_zip: whether UNNEST over multiple arrays may be zipped with
            ARRAYS_ZIP. When False, such UNNESTs raise UnsupportedError.

    Returns:
        The transformed expression.

    Raises:
        UnsupportedError: when a multi-array UNNEST cannot be zipped, or when a joined
            single-array UNNEST does not have exactly 1 or 2 column aliases.
    """

    # Zips several input arrays into a single ARRAYS_ZIP(...) argument and writes it back onto
    # the UNNEST node; single-array inputs are returned unchanged.
    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    # Picks the table-generating function: POSEXPLODE when an offset is requested, INLINE for
    # zipped arrays, EXPLODE otherwise.
    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        # FROM UNNEST(...) [AS t(cols)] -> FROM <UDTF>(...) [AS t(cols)]
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            # NOTE(review): if `alias.columns` returns the alias's own list, the insert below also
            # mutates the original alias -- confirm this is intended.
            columns = alias.columns if alias else []
            offset = unnest.args.get("offset")
            if offset:
                # The position column goes first, presumably to match POSEXPLODE's output order.
                columns.insert(
                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
                )

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
                )
            )

        # [CROSS] JOIN [LATERAL] UNNEST(...) -> LATERAL VIEW <UDTF>(...)
        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For LATERAL UNNEST the alias lives on the Lateral node, not on the Unnest.
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                # The join is dropped; it is re-expressed as a LATERAL VIEW below.
                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: an exception is raised when there are 0 or > 2 aliases.
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                offset = unnest.args.get("offset")
                if offset:
                    alias_cols.insert(
                        0,
                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
                    )

                # zip() stops at the shorter list, so this emits one LATERAL VIEW per (possibly
                # zipped) array expression as long as there is at least one alias column;
                # `column` itself is unused -- every view receives the full alias_cols list.
                # NOTE(review): for a multi-array UNNEST without an alias but with an offset,
                # `alias` is None here and `alias.this` raises AttributeError.
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_projection_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode projections into unnests.

    Every projection of a SELECT that contains an `[POS]EXPLODE(arr)` is rewritten as follows:
    - `arr` is CROSS JOINed as `UNNEST(arr) WITH OFFSET`.
    - A single shared "series" source, `UNNEST(GENERATE_SERIES(index_offset, <end>))`, is
      CROSS JOINed (or used as FROM when the query has none).
    - WHERE predicates line up each array's offset with the series position.
    - The exploded value in the projection is replaced by `IF(series.pos = u.pos, u.col)`.

    Args:
        index_offset: start value of the generated series; it is also used to adjust the
            per-array upper bound and the series end. NOTE(review): presumably the target
            dialect's array index base (0 or 1) -- confirm against callers.

    Returns:
        A transform that rewrites `exp.Select` nodes in place and returns any other
        expression unchanged.
    """

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources; generated aliases must avoid them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a fresh name derived from `name` and reserve it in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded array; their GREATEST bounds the series.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an Alias node (`alias`) and pick up any
                    # user-supplied names for the exploded value and its position.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Multiple aliases: the first names the position, the second the value.
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # OUTER variant: when the array is NULL or empty, unnest a one-element
                        # array instead so the row is not dropped. NOTE(review): the element is
                        # a safe/offset bracket access `arr[0]`, presumably evaluating to NULL.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # A position alias is always needed for the UNNEST ... WITH OFFSET below.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Projected value: the element whose offset matches the current series row.
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also projects the position, right after the value column.
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # First exploded projection: attach the shared series source exactly once.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # From here on `size` is this array's upper bound for the predicate below
                    # (ARRAY_SIZE - 1 unless index_offset is 1).
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position equals this array's offset, or where
                    # the series has run past the bound and the offset sits on the bound.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            # The series ends at the largest array size, shifted by the index base.
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest

Convert explode/posexplode projections into unnests.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Rewrite a two-argument percentile call into its WITHIN GROUP form.

    The call's first argument (`this`) becomes the ORDER BY key of a new WITHIN GROUP
    clause, and its second argument (`expression`) becomes the function's only argument.
    Nodes that are not percentiles, are already wrapped in WITHIN GROUP, or have no
    second argument are returned unchanged.
    """
    if (
        not isinstance(expression, exp.PERCENTILES)
        or isinstance(expression.parent, exp.WithinGroup)
        or not expression.expression
    ):
        return expression

    order_key = expression.this.pop()
    expression.set("this", expression.expression.pop())
    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=order_key)]),
    )

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Replace `<percentile>(q) WITHIN GROUP (ORDER BY x)` with an `ApproxQuantile(x, q)` node.

    The first ORDER BY key becomes the quantile's input and the percentile's own argument
    becomes the quantile. Anything else is returned unchanged.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    first_key = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=first_key.this, quantile=percentile.this)
    return expression.replace(replacement)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Give every column-less CTE of a recursive WITH explicit column names.

    Names come from the CTE query's projection output names. When the query is a set
    operation, its left operand's projections are used. Projections without a name get
    generated names `_c_0`, `_c_1`, ... (from `name_sequence("_c_")`), shared across the
    whole WITH clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fallback_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        body = cte.this
        if isinstance(body, exp.SetOperation):
            body = body.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or fallback_name())
            for projection in body.selects
        ]
        cte_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal.

    `CAST('epoch' AS <temporal type>)` (or TRY_CAST, case-insensitive on 'epoch') has its
    operand swapped for the literal '1970-01-01 00:00:00'. The node is modified in place
    and returned.
    """
    is_cast = isinstance(expression, (exp.Cast, exp.TryCast))
    if is_cast and expression.name.lower() == "epoch":
        target_type = expression.to.this
        if target_type in exp.DataType.TEMPORAL_TYPES:
            expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    `... SEMI JOIN t ON cond` is removed and `EXISTS(SELECT 1 FROM t WHERE cond)` is added
    to the query's WHERE clause. For ANTI joins the predicate is `NOT EXISTS(...)`.
    SEMI/ANTI joins without an ON condition are left in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` removes the join from the list stored in
        # `expression.args["joins"]` in place. Iterating that live list would skip the join
        # that follows every removed one (e.g. the second of two consecutive SEMI joins).
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join with an ON or USING condition;
    any other expression is returned unchanged.

    The LEFT-join half keeps every row of the left side. The RIGHT-join half is filtered
    with `NOT EXISTS (SELECT 1 FROM <from table> WHERE <join condition>)` so it only adds
    right rows that have no match. The two halves are combined with UNION ALL.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            index, full_outer_join = full_outer_joins[0]
            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)

            on = full_outer_join.args.get("on")
            using = full_outer_join.args.get("using") or []
            join_conditions: exp.Expression
            if on:
                join_conditions = on
            elif using:
                join_conditions = exp.and_(
                    *[exp.column(col, tables[0]).eq(exp.column(col, tables[1])) for col in using]
                )
            else:
                # No ON/USING (e.g. a NATURAL FULL JOIN): there is no condition to build the
                # anti-join from. Leave the query untouched instead of failing on `None`;
                # nothing has been mutated yet at this point.
                return expression

            # Copy before dropping the LIMIT so only the RIGHT half keeps it.
            expression_copy = expression.copy()
            expression.set("limit", None)

            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: ~E) -> ~E:
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    If `expression` has no top-level WITH, the first nested WITH found by `find_all` is
    promoted to be it. Otherwise nested CTEs are merged into the top-level WITH:
    - CTEs nested inside another CTE are inserted just before that CTE.
    - All other nested CTEs are appended at the end.
    A RECURSIVE flag on any nested WITH is carried over to the top-level one.

    Args:
        expression: the root expression; it is modified in place.

    Returns:
        The same `expression` object.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # A WITH attached directly to the root is already top-level.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: detach this nested one and promote it wholesale.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # Look up the enclosing CTE before detaching, while ancestry is still intact.
            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Place the nested CTEs before the CTE that contained them, so they are
                # defined ahead of the CTE that uses them.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                # Re-set the (mutated) list so the newly inserted CTEs are attached to
                # the top-level WITH.
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions.

    The optimizer's `ensure_bools` helper is applied to every node of the tree with a
    callback. The callback rewrites an operand to `<operand> <> 0` when any of these hold:
    - it is a number;
    - it is not a subquery predicate and its type is UNKNOWN or numeric;
    - it is a column with no type.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as canonicalize_ensure_bools

    def _to_explicit_bool(operand: exp.Expression) -> None:
        looks_numeric = operand.is_number or (
            not isinstance(operand, exp.SubqueryPredicate)
            and operand.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
        )
        if looks_numeric or (isinstance(operand, exp.Column) and not operand.type):
            operand.replace(operand.neq(0))

    for node in expression.walk():
        canonicalize_ensure_bools(node, _to_explicit_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strip table/db/catalog qualifiers from every column reference in `expression`.

    Only the last part of each column (the column name itself) is kept. The tree is
    modified in place and returned.
    """
    for column in expression.find_all(exp.Column):
        qualifiers = column.parts[:-1]  # catalog / db / table parts, never the name
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Drop every column-level UNIQUE constraint from a CREATE statement.

    Each `UniqueColumnConstraint` is removed together with its parent node (presumably the
    `ColumnConstraint` that wraps it). The statement is modified in place and returned.

    Raises:
        AssertionError: if `expression` is not an `exp.Create` (when asserts are enabled).
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        wrapper = unique.parent
        if wrapper:
            wrapper.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """Map `CREATE TEMPORARY TABLE` statements onto temporary views.

    Behavior by case:
    - A temporary CTAS (`CREATE TEMPORARY TABLE t AS <query>`) becomes a new
      `CREATE TEMPORARY VIEW t AS <query>`. Only the name and the query are carried
      over; other arguments such as properties are not.
    - A temporary table without a query is passed to `tmp_storage_provider` (identity
      by default) and its result is returned.
    - Any other CREATE is returned unchanged.
    """
    assert isinstance(expression, exp.Create)

    if expression.kind != "TABLE":
        return expression

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    if not any(isinstance(prop, exp.TemporaryProperty) for prop in property_list):
        return expression

    query = expression.expression
    if not query:
        return tmp_storage_provider(expression)

    return exp.Create(kind="TEMPORARY VIEW", this=expression.this, expression=query)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Column names are matched case-insensitively. Only TABLE/VIEW creates that already have
    a schema are affected, and only when PARTITIONED BY is not already a schema.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if (
        not partitioned_by
        or not partitioned_by.this
        or isinstance(partitioned_by.this, exp.Schema)
    ):
        return expression

    wanted = {ref.name.upper() for ref in partitioned_by.this.expressions}
    partition_columns = [c for c in schema.expressions if c.name.upper() in wanted]
    remaining = [c for c in schema.expressions if c not in partition_columns]

    schema.set("expressions", remaining)
    partitioned_by.replace(
        exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns))
    )
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When PARTITIONED BY is a schema made only of typed column definitions, those
    definitions are appended to the table's own schema, and PARTITIONED BY is reduced to
    a tuple of the column names.
    """
    assert isinstance(expression, exp.Create)

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not partitioned_by or not partitioned_by.this:
        return expression

    partition_schema = partitioned_by.this
    if not isinstance(partition_schema, exp.Schema):
        return expression

    column_defs = partition_schema.expressions
    if not all(isinstance(d, exp.ColumnDef) and d.kind for d in column_defs):
        return expression

    name_tuple = exp.Tuple(expressions=[exp.to_identifier(d.this) for d in column_defs])
    table_schema = expression.this
    for column_def in column_defs:
        table_schema.append("expressions", column_def)
    partitioned_by.set("this", name_tuple)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Key/value (`PropertyEQ`) arguments of a STRUCT become `<value> AS <key>`. Every other
    argument, and every non-STRUCT node, is left untouched.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    def _as_alias(arg: exp.Expression) -> exp.Expression:
        if isinstance(arg, exp.PropertyEQ):
            return exp.alias_(arg.expression, arg.this)
        return arg

    expression.set("expressions", list(map(_as_alias, expression.expressions)))
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """Convert Oracle-style `(+)` join marks in WHERE clauses into explicit LEFT JOINs.

    Rules, from https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

    -- example with WHERE
    SELECT d.department_name, sum(e.salary) as total_salary
    FROM departments d, employees e
    WHERE e.department_id(+) = d.department_id
    group by department_name

    -- example of left correlation in select
    SELECT d.department_name, (
        SELECT SUM(e.salary)
            FROM employees e
            WHERE e.department_id(+) = d.department_id) AS total_salary
    FROM departments d;

    -- example of left correlation in from
    SELECT d.department_name, t.total_salary
    FROM departments d, (
            SELECT SUM(e.salary) AS total_salary
            FROM employees e
            WHERE e.department_id(+) = d.department_id
        ) t

    Raises:
        AssertionError: (when asserts are enabled) if any of these hold:
            - a scope is a correlated subquery;
            - the WHERE clause cannot be normalized;
            - one predicate marks columns of more than one table;
            - no table is left to serve as the new FROM clause.
    """

    from sqlglot.optimizer.scope import traverse_scope
    from sqlglot.optimizer.normalize import normalize, normalized
    from collections import defaultdict

    # we go in reverse to check the main query for left correlation
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where = query.args.get("where")
        # `or []` also guards against an explicit None value, not just a missing key.
        joins = query.args.get("joins") or []

        # knockout: we do not support left correlation (see point 2)
        assert not scope.is_correlated_subquery, "Correlated queries are not supported"

        # nothing to do - we check it here after knockout above
        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
            continue

        # make sure we have AND of ORs to have clear join terms
        where = normalize(where.this)
        assert normalized(where), "Cannot normalize JOIN predicates"

        joins_ons = defaultdict(list)  # dict of {name: list of join AND conditions}
        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]

            left_join_table = set(col.table for col in join_cols)
            if not left_join_table:
                continue

            assert not (
                len(left_join_table) > 1
            ), "Cannot combine JOIN predicates from different tables"

            for col in join_cols:
                col.set("join_mark", False)

            joins_ons[left_join_table.pop()].append(cond)

        old_joins = {join.alias_or_name: join for join in joins}
        new_joins = {}
        query_from = query.args["from"]

        for table, predicates in joins_ons.items():
            # The marked table becomes the LEFT-joined side; it may be a join or the FROM.
            join_what = old_joins.get(table, query_from).this.copy()
            new_joins[join_what.alias_or_name] = exp.Join(
                this=join_what, on=exp.and_(*predicates), kind="LEFT"
            )

            # Remove each predicate (and any wrapping parentheses) from the WHERE tree.
            for p in predicates:
                while isinstance(p.parent, exp.Paren):
                    p.parent.replace(p)

                parent = p.parent
                p.pop()
                if isinstance(parent, exp.Binary):
                    parent.replace(parent.right if parent.left is None else parent.left)
                elif isinstance(parent, exp.Where):
                    parent.pop()

        if query_from.alias_or_name in new_joins:
            # The FROM table is now LEFT-joined; promote the first untouched join (in the
            # original join order, so the output is deterministic) to be the new FROM.
            only_old_joins = [name for name in old_joins if name not in new_joins]
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = only_old_joins[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            # Compare by alias_or_name (the keys of `old_joins`). Comparing by table name
            # would re-add an aliased, newly promoted FROM table as an extra CROSS JOIN.
            from_name = query.args["from"].alias_or_name
            for n, j in old_joins.items():  # preserve any other joins
                if n not in new_joins and n != from_name:
                    if not j.kind:
                        j.set("kind", "CROSS")
                    new_joins[n] = j
            query.set("joins", list(new_joins.values()))

    return expression

https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

  1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

  2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

-- example with WHERE
SELECT d.department_name, sum(e.salary) as total_salary
FROM departments d, employees e
WHERE e.department_id(+) = d.department_id
group by department_name

-- example of left correlation in select
SELECT d.department_name, (
    SELECT SUM(e.salary)
        FROM employees e
        WHERE e.department_id(+) = d.department_id) AS total_salary
FROM departments d;

-- example of left correlation in from
SELECT d.department_name, t.total_salary
FROM departments d, (
        SELECT SUM(e.salary) AS total_salary
        FROM employees e
        WHERE e.department_id(+) = d.department_id
    ) t

def any_to_exists( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 980def any_to_exists(expression: exp.Expression) -> exp.Expression:
 981    """
 982    Transform ANY operator to Spark's EXISTS
 983
 984    For example,
 985        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
 986        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
 987
 988    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
 989    transformation
 990    """
 991    if isinstance(expression, exp.Select):
 992        for any_expr in expression.find_all(exp.Any):
 993            this = any_expr.this
 994            if isinstance(this, exp.Query):
 995                continue
 996
 997            binop = any_expr.parent
 998            if isinstance(binop, exp.Binary):
 999                lambda_arg = exp.to_identifier("x")
1000                any_expr.replace(lambda_arg)
1001                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
1002                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))
1003
1004    return expression

Transform ANY operator to Spark's EXISTS

For example,

- Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
- Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> 5 > x)

Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation

def eliminate_window_clause( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_window_clause(expression: exp.Expression) -> exp.Expression:
    """Eliminates the `WINDOW` query clause by inlining each named window.

    The WINDOW clause is removed from the SELECT. Named windows are resolved in declaration
    order, so a named window may build on an earlier one. Any window in the SELECT's scope
    that refers to a named window gets a copy of that window's `partition_by`, `order` and
    `spec` arguments, and its reference is cleared. Window names are compared
    case-insensitively.
    """
    if not isinstance(expression, exp.Select) or not expression.args.get("windows"):
        return expression

    from sqlglot.optimizer.scope import find_all_in_scope

    named_windows = expression.args["windows"]
    expression.set("windows", None)

    definitions: t.Dict[str, exp.Expression] = {}

    def _inline(window: exp.Expression) -> None:
        base = definitions.get(window.alias.lower())
        if not base:
            return

        window.set("alias", None)
        for key in ("partition_by", "order", "spec"):
            value = base.args.get(key)
            if value:
                window.set(key, value.copy())

    for named in named_windows:
        _inline(named)
        definitions[named.name.lower()] = named

    for window in find_all_in_scope(expression, exp.Window):
        _inline(window)

    return expression

Eliminates the WINDOW query clause by inlining each named window.