Edit on GitHub

sqlglot.transforms

   1from __future__ import annotations
   2
   3import typing as t
   4
   5from sqlglot import expressions as exp
   6from sqlglot.errors import UnsupportedError
   7from sqlglot.helper import find_new_name, name_sequence
   8
   9
  10if t.TYPE_CHECKING:
  11    from sqlglot._typing import E
  12    from sqlglot.generator import Generator
  13
  14
  15def preprocess(
  16    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
  17) -> t.Callable[[Generator, exp.Expression], str]:
  18    """
  19    Creates a new transform by chaining a sequence of transformations and converts the resulting
  20    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
  21    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
  22
  23    Args:
  24        transforms: sequence of transform functions. These will be called in order.
  25
  26    Returns:
  27        Function that can be used as a generator transform.
  28    """
  29
  30    def _to_sql(self, expression: exp.Expression) -> str:
  31        expression_type = type(expression)
  32
  33        try:
  34            expression = transforms[0](expression)
  35            for transform in transforms[1:]:
  36                expression = transform(expression)
  37        except UnsupportedError as unsupported_error:
  38            self.unsupported(str(unsupported_error))
  39
  40        _sql_handler = getattr(self, expression.key + "_sql", None)
  41        if _sql_handler:
  42            return _sql_handler(expression)
  43
  44        transforms_handler = self.TRANSFORMS.get(type(expression))
  45        if transforms_handler:
  46            if expression_type is type(expression):
  47                if isinstance(expression, exp.Func):
  48                    return self.function_fallback_sql(expression)
  49
  50                # Ensures we don't enter an infinite loop. This can happen when the original expression
  51                # has the same type as the final expression and there's no _sql method available for it,
  52                # because then it'd re-enter _to_sql.
  53                raise ValueError(
  54                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
  55                )
  56
  57            return transforms_handler(self, expression)
  58
  59        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
  60
  61    return _to_sql
  62
  63
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """
    Replace `UNNEST(GENERATE_DATE_ARRAY(...))` sources of a SELECT with recursive CTEs.

    Only `exp.Unnest` nodes that sit directly under a FROM or JOIN, wrap exactly one
    `exp.GenerateDateArray`, and whose array has a start, an end and an `exp.Interval` step
    are rewritten. Each one is replaced by a subquery over a CTE named `_generated_dates`
    (suffixed with `_<n>` from the second one onwards), and the CTEs are prepended to the
    SELECT's WITH clause, which is flagged as recursive.

    Args:
        expression: the expression that will be transformed (modified in place).

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            # Reuse the first column of the UNNEST's table alias, if any, as the output column name
            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # The first CTE gets the bare name, subsequent ones get a numeric suffix
            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            # Anchor member: the start date. Recursive member: previous date plus one step,
            # kept while it does not go past the end date.
            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            # The UNNEST is swapped for a subquery that reads from the CTE and is aliased after it
            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            # The generated CTEs are placed before any CTEs the query already had
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
 119
 120
 121def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
 122    """Unnests GENERATE_SERIES or SEQUENCE table references."""
 123    this = expression.this
 124    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
 125        unnest = exp.Unnest(expressions=[this])
 126        if expression.alias:
 127            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
 128
 129        return unnest
 130
 131    return expression
 132
 133
 134def unalias_group(expression: exp.Expression) -> exp.Expression:
 135    """
 136    Replace references to select aliases in GROUP BY clauses.
 137
 138    Example:
 139        >>> import sqlglot
 140        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 141        'SELECT a AS b FROM x GROUP BY 1'
 142
 143    Args:
 144        expression: the expression that will be transformed.
 145
 146    Returns:
 147        The transformed expression.
 148    """
 149    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 150        aliased_selects = {
 151            e.alias: i
 152            for i, e in enumerate(expression.parent.expressions, start=1)
 153            if isinstance(e, exp.Alias)
 154        }
 155
 156        for group_by in expression.expressions:
 157            if (
 158                isinstance(group_by, exp.Column)
 159                and not group_by.table
 160                and group_by.name in aliased_selects
 161            ):
 162                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 163
 164    return expression
 165
 166
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The original query is extended with a `ROW_NUMBER()` window projection partitioned by the
    DISTINCT ON columns, wrapped in a subquery aliased `_t`, and the outer query keeps only the
    rows whose row number equals 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        # Pick a name for the row number projection that doesn't clash with existing projections
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        # The DISTINCT clause is detached from the query; its ON columns become the partition
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        # The query's ORDER BY (if any) is moved into the window; otherwise the window is
        # ordered by copies of the DISTINCT ON columns
        order = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number_window_alias)
        expression.select(window, copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects = []
        taken_names = {row_number_window_alias}
        # The last projection is the row number window added above, so it is skipped
        for select in expression.selects[:-1]:
            if select.is_star:
                # A star projection makes the outer query select everything
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                # Preserve the quoting of the column's identifier on the generated alias
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression
 221
 222
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new SELECT over the original query (as subquery `_t`) filtered by the former QUALIFY
        condition, or the input expression if it is not a SELECT with a QUALIFY clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so that the outer query can reference it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Build the outer projection for `select`, keeping the quoting of its identifier
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are computed before any helper projections are added below
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection only windows are collected, not the referenced columns
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Columns that match a projection alias are replaced by the aliased expression
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the inner query and refer to it by alias in the filter
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window was the whole QUALIFY condition, so the filter is just the alias
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A filtered column that isn't projected yet is added to the inner query
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
 285
 286
 287def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
 288    """
 289    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
 290    other expressions. This transforms removes the precision from parameterized types in expressions.
 291    """
 292    for node in expression.find_all(exp.DataType):
 293        node.set(
 294            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
 295        )
 296
 297    return expression
 298
 299
 300def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
 301    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
 302    from sqlglot.optimizer.scope import find_all_in_scope
 303
 304    if isinstance(expression, exp.Select):
 305        unnest_aliases = {
 306            unnest.alias
 307            for unnest in find_all_in_scope(expression, exp.Unnest)
 308            if isinstance(unnest.parent, (exp.From, exp.Join))
 309        }
 310        if unnest_aliases:
 311            for column in expression.find_all(exp.Column):
 312                leftmost_part = column.parts[0]
 313                if leftmost_part.arg_key != "this" and leftmost_part.this in unnest_aliases:
 314                    leftmost_part.pop()
 315
 316    return expression
 317
 318
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST used as the FROM source is replaced by a table wrapping the corresponding UDTF,
    while UNNESTs used as join sources (optionally wrapped in a LATERAL) are removed from the
    joins and appended to the SELECT's "laterals" as lateral views.

    Args:
        expression: the expression that will be transformed (modified in place).
        unnest_using_arrays_zip: whether an UNNEST with several input arrays may be rewritten
            using `ARRAYS_ZIP`. When False, such an UNNEST raises `UnsupportedError`.

    Returns:
        The transformed expression.

    Raises:
        UnsupportedError: if an UNNEST has multiple input arrays and `unnest_using_arrays_zip`
            is False, or if a joined single-array UNNEST doesn't have 1 or 2 column aliases.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Returns the expressions to explode: the inputs themselves for a single array, or a
        # single ARRAYS_ZIP(...) call (also stored back on `u`) for several arrays.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # An UNNEST with an offset maps to POSEXPLODE; otherwise INLINE is used for zipped
        # arrays and EXPLODE for a single array.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        # Iterate over a copy, because matching joins are removed from `joins` inside the loop
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For a LATERAL-wrapped UNNEST the alias is read from the LATERAL node
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                # One lateral view is appended per (expression, alias column) pair; each one
                # carries the full list of alias columns. `column` itself is not used.
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
 410
 411
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode projections into unnests.

    Args:
        index_offset: the value the generated position series starts from. It is also used to
            decide whether the array sizes and the series' end need to be adjusted.

    Returns:
        A transform that rewrites `exp.Select` nodes in place (any other expression is
        returned untouched) and returns the expression it was given.
    """

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Finds a name not present in `names` and reserves it right away
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Sizes of all exploded arrays, used at the end to bound the position series
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Shared position series; its "end" is only known after the loop and is set there
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an alias node, remembering any user-given names
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # The explode node is looked up again in the newly created alias
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # For the OUTER variant an empty (or NULL, via COALESCE) array is swapped
                        # for a one-element array built from a "safe" access to its first element
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate names for whatever the user did not name explicitly
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The projection only yields the unnested value on the row where the series
                    # position matches the unnest's own position (no false branch is given)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also outputs the position, inserted right after the value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The position series is attached to the query only once, on the first explode
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Note that `arrays` holds the unadjusted size; only the filter below uses
                    # the decremented one
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, or where the series has gone past
                    # this array's size and the unnest position equals that size
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the largest of the exploded arrays
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest
 563
 564
 565def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
 566    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
 567    if (
 568        isinstance(expression, exp.PERCENTILES)
 569        and not isinstance(expression.parent, exp.WithinGroup)
 570        and expression.expression
 571    ):
 572        column = expression.this.pop()
 573        expression.set("this", expression.expression.pop())
 574        order = exp.Order(expressions=[exp.Ordered(this=column)])
 575        expression = exp.WithinGroup(this=expression, expression=order)
 576
 577    return expression
 578
 579
 580def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
 581    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
 582    if (
 583        isinstance(expression, exp.WithinGroup)
 584        and isinstance(expression.this, exp.PERCENTILES)
 585        and isinstance(expression.expression, exp.Order)
 586    ):
 587        quantile = expression.this.this
 588        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
 589        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
 590
 591    return expression
 592
 593
 594def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
 595    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
 596    if isinstance(expression, exp.With) and expression.recursive:
 597        next_name = name_sequence("_c_")
 598
 599        for cte in expression.expressions:
 600            if not cte.args["alias"].columns:
 601                query = cte.this
 602                if isinstance(query, exp.SetOperation):
 603                    query = query.this
 604
 605                cte.args["alias"].set(
 606                    "columns",
 607                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
 608                )
 609
 610    return expression
 611
 612
 613def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
 614    """Replace 'epoch' in casts by the equivalent date literal."""
 615    if (
 616        isinstance(expression, (exp.Cast, exp.TryCast))
 617        and expression.name.lower() == "epoch"
 618        and expression.to.this in exp.DataType.TEMPORAL_TYPES
 619    ):
 620        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
 621
 622    return expression
 623
 624
 625def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
 626    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
 627    if isinstance(expression, exp.Select):
 628        for join in expression.args.get("joins") or []:
 629            on = join.args.get("on")
 630            if on and join.kind in ("SEMI", "ANTI"):
 631                subquery = exp.select("1").from_(join.this).where(on)
 632                exists = exp.Exists(this=subquery)
 633                if join.kind == "ANTI":
 634                    exists = exists.not_(copy=False)
 635
 636                join.pop()
 637                expression.where(exists, copy=False)
 638
 639    return expression
 640
 641
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A non-distinct union of the LEFT-join and RIGHT-join variants of the query, or the input
        expression if it is not a SELECT with exactly one FULL join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # The copy is taken before the LIMIT is cleared, so only the copy keeps it
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # Without an ON condition, one is derived from the USING columns, qualified with
            # the names of the FROM source and of the joined source
            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            # The RIGHT-join copy is filtered with NOT EXISTS over the FROM source and the join
            # condition, which excludes the rows that have a match on the left side
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
 678
 679
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed (modified in place).

    Returns:
        The transformed expression.
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # The top-level WITH clause itself stays where it is
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: the first nested one is promoted as is
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # The enclosing CTE (if any) must be looked up before the WITH is detached
            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # CTEs nested in another CTE are inserted right before that CTE
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                # Otherwise they are appended after the existing top-level CTEs
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
 717
 718
 719def ensure_bools(expression: exp.Expression) -> exp.Expression:
 720    """Converts numeric values used in conditions into explicit boolean expressions."""
 721    from sqlglot.optimizer.canonicalize import ensure_bools
 722
 723    def _ensure_bool(node: exp.Expression) -> None:
 724        if (
 725            node.is_number
 726            or (
 727                not isinstance(node, exp.SubqueryPredicate)
 728                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
 729            )
 730            or (isinstance(node, exp.Column) and not node.type)
 731        ):
 732            node.replace(node.neq(0))
 733
 734    for node in expression.walk():
 735        ensure_bools(node, _ensure_bool)
 736
 737    return expression
 738
 739
 740def unqualify_columns(expression: exp.Expression) -> exp.Expression:
 741    for column in expression.find_all(exp.Column):
 742        # We only wanna pop off the table, db, catalog args
 743        for part in column.parts[:-1]:
 744            part.pop()
 745
 746    return expression
 747
 748
 749def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
 750    assert isinstance(expression, exp.Create)
 751    for constraint in expression.find_all(exp.UniqueColumnConstraint):
 752        if constraint.parent:
 753            constraint.parent.pop()
 754
 755    return expression
 756
 757
 758def ctas_with_tmp_tables_to_create_tmp_view(
 759    expression: exp.Expression,
 760    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
 761) -> exp.Expression:
 762    assert isinstance(expression, exp.Create)
 763    properties = expression.args.get("properties")
 764    temporary = any(
 765        isinstance(prop, exp.TemporaryProperty)
 766        for prop in (properties.expressions if properties else [])
 767    )
 768
 769    # CTAS with temp tables map to CREATE TEMPORARY VIEW
 770    if expression.kind == "TABLE" and temporary:
 771        if expression.expression:
 772            return exp.Create(
 773                kind="TEMPORARY VIEW",
 774                this=expression.this,
 775                expression=expression.expression,
 776            )
 777        return tmp_storage_provider(expression)
 778
 779    return expression
 780
 781
 782def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
 783    """
 784    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
 785    PARTITIONED BY value is an array of column names, they are transformed into a schema.
 786    The corresponding columns are removed from the create statement.
 787    """
 788    assert isinstance(expression, exp.Create)
 789    has_schema = isinstance(expression.this, exp.Schema)
 790    is_partitionable = expression.kind in {"TABLE", "VIEW"}
 791
 792    if has_schema and is_partitionable:
 793        prop = expression.find(exp.PartitionedByProperty)
 794        if prop and prop.this and not isinstance(prop.this, exp.Schema):
 795            schema = expression.this
 796            columns = {v.name.upper() for v in prop.this.expressions}
 797            partitions = [col for col in schema.expressions if col.name.upper() in columns]
 798            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
 799            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
 800            expression.set("this", schema)
 801
 802    return expression
 803
 804
 805def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
 806    """
 807    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
 808
 809    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
 810    """
 811    assert isinstance(expression, exp.Create)
 812    prop = expression.find(exp.PartitionedByProperty)
 813    if (
 814        prop
 815        and prop.this
 816        and isinstance(prop.this, exp.Schema)
 817        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
 818    ):
 819        prop_this = exp.Tuple(
 820            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
 821        )
 822        schema = expression.this
 823        for e in prop.this.expressions:
 824            schema.append("expressions", e)
 825        prop.set("this", prop_this)
 826
 827    return expression
 828
 829
 830def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
 831    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
 832    if isinstance(expression, exp.Struct):
 833        expression.set(
 834            "expressions",
 835            [
 836                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
 837                for e in expression.expressions
 838            ],
 839        )
 840
 841    return expression
 842
 843
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Each scope of the query is processed independently. Scopes that lack either a WHERE
    clause or joins are left untouched.

    Args:
        expression: The AST to remove join marks from. It is mutated in place.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: if a marked column is not part of a binary predicate, if both sides
            of a predicate contain marked columns, if a marked column is unqualified, if the
            marked columns of a predicate belong to more than one table, or if no source is
            left to take over the FROM clause.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            # Closest enclosing predicate; stopping at a Select makes the assertion below fail
            # for marked columns that are not wrapped in any predicate
            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Remember the parent first: popping detaches the predicate from the WHERE clause
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the tables the marked columns belong to
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # Add predicate if join already copied, or add join if it is new
            # NOTE: `col` is the last column of the loop above; the assertion guarantees that
            # all marked columns share its table. If that table is not one of the joined
            # sources, the FROM source is used instead.
            join_this = old_joins.get(col.table, query_from).this
            existing_join = new_joins.get(join_this.alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args["on"], join_predicate))
            else:
                new_joins[join_this.alias_or_name] = exp.Join(
                    this=join_this.copy(), on=join_predicate.copy(), kind="LEFT"
                )

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # Names of the original join sources that were not turned into LEFT JOINs
        only_old_join_sources = old_joins.keys() - new_joins.keys()

        # The FROM source became the target of a LEFT JOIN, so one of the untouched join
        # sources has to take its place in the FROM clause
        if query_from.alias_or_name in new_joins:
            assert (
                len(only_old_join_sources) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_join_sources)[0]
            query.set("from", exp.From(this=old_joins.pop(new_from_name).this))
            only_old_join_sources.remove(new_from_name)

        if new_joins:
            # Untouched joins are kept after the new ones; those without a kind become CROSS joins
            only_old_join_expressions = []
            for old_join_source in only_old_join_sources:
                old_join_expression = old_joins[old_join_source]
                if not old_join_expression.kind:
                    old_join_expression.set("kind", "CROSS")

                only_old_join_expressions.append(old_join_expression)

            query.set("joins", list(new_joins.values()) + only_old_join_expressions)

        # Every predicate was moved into a join, so the WHERE clause is empty and can go
        if not where.this:
            where.pop()

    return expression
 954
 955
 956def any_to_exists(expression: exp.Expression) -> exp.Expression:
 957    """
 958    Transform ANY operator to Spark's EXISTS
 959
 960    For example,
 961        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
 962        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
 963
 964    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
 965    transformation
 966    """
 967    if isinstance(expression, exp.Select):
 968        for any_expr in expression.find_all(exp.Any):
 969            this = any_expr.this
 970            if isinstance(this, exp.Query):
 971                continue
 972
 973            binop = any_expr.parent
 974            if isinstance(binop, exp.Binary):
 975                lambda_arg = exp.to_identifier("x")
 976                any_expr.replace(lambda_arg)
 977                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
 978                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))
 979
 980    return expression
 981
 982
 983def eliminate_window_clause(expression: exp.Expression) -> exp.Expression:
 984    """Eliminates the `WINDOW` query clause by inling each named window."""
 985    if isinstance(expression, exp.Select) and expression.args.get("windows"):
 986        from sqlglot.optimizer.scope import find_all_in_scope
 987
 988        windows = expression.args["windows"]
 989        expression.set("windows", None)
 990
 991        window_expression: t.Dict[str, exp.Expression] = {}
 992
 993        def _inline_inherited_window(window: exp.Expression) -> None:
 994            inherited_window = window_expression.get(window.alias.lower())
 995            if not inherited_window:
 996                return
 997
 998            window.set("alias", None)
 999            for key in ("partition_by", "order", "spec"):
1000                arg = inherited_window.args.get(key)
1001                if arg:
1002                    window.set(key, arg.copy())
1003
1004        for window in windows:
1005            _inline_inherited_window(window)
1006            window_expression[window.name.lower()] = window
1007
1008        for window in find_all_in_scope(expression, exp.Window):
1009            _inline_inherited_window(window)
1010
1011    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is accepted and leaves the expression untouched.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the final expression cannot be
            converted to SQL, i.e. when there is neither a "_sql" method nor a usable
            `Generator.TRANSFORMS` entry for it.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            # Iterating over the whole sequence (instead of special-casing the first element)
            # makes an empty `transforms` a no-op rather than an IndexError
            for transform in transforms:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            # The error is reported through the generator and SQL generation carries on with
            # the expression produced by the transforms that did succeed
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 65def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 66    if isinstance(expression, exp.Select):
 67        count = 0
 68        recursive_ctes = []
 69
 70        for unnest in expression.find_all(exp.Unnest):
 71            if (
 72                not isinstance(unnest.parent, (exp.From, exp.Join))
 73                or len(unnest.expressions) != 1
 74                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 75            ):
 76                continue
 77
 78            generate_date_array = unnest.expressions[0]
 79            start = generate_date_array.args.get("start")
 80            end = generate_date_array.args.get("end")
 81            step = generate_date_array.args.get("step")
 82
 83            if not start or not end or not isinstance(step, exp.Interval):
 84                continue
 85
 86            alias = unnest.args.get("alias")
 87            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 88
 89            start = exp.cast(start, "date")
 90            date_add = exp.func(
 91                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 92            )
 93            cast_date_add = exp.cast(date_add, "date")
 94
 95            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 96
 97            base_query = exp.select(start.as_(column_name))
 98            recursive_query = (
 99                exp.select(cast_date_add)
100                .from_(cte_name)
101                .where(cast_date_add <= exp.cast(end, "date"))
102            )
103            cte_query = base_query.union(recursive_query, distinct=False)
104
105            generate_dates_query = exp.select(column_name).from_(cte_name)
106            unnest.replace(generate_dates_query.subquery(cte_name))
107
108            recursive_ctes.append(
109                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
110            )
111            count += 1
112
113        if recursive_ctes:
114            with_expression = expression.args.get("with") or exp.With()
115            with_expression.set("recursive", True)
116            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
117            expression.set("with", with_expression)
118
119    return expression
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A table whose source is a `GenerateSeries` call is replaced by an UNNEST of that call.
    If the table had an alias, the UNNEST is aliased as `_u` and the original alias is
    used as the name of its column.

    Args:
        expression: the expression to transform.

    Returns:
        The (possibly aliased) UNNEST, or the input unchanged when it is not such a table.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    if not expression.alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[expression.alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    An unqualified column of the GROUP BY clause whose name matches the alias of a projection
    is replaced by the 1-based position of that projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)):
        return expression

    position_by_alias = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for term in expression.expressions:
        if isinstance(term, exp.Column) and not term.table:
            name = term.name
            if name in position_by_alias:
                term.replace(exp.Literal.number(position_by_alias.get(name)))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The original query gets an extra ROW_NUMBER() projection, partitioned by the DISTINCT ON
    columns and ordered by the query's ORDER BY clause (or by the DISTINCT ON columns when
    there is none). It is then wrapped in a subquery named `_t`, and the outer query keeps
    only the rows whose row number is 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct or not isinstance(distinct.args.get("on"), exp.Tuple):
        return expression

    row_number_alias = find_new_name(expression.named_selects, "_row_number")

    partition_cols = distinct.pop().args["on"].expressions
    numbered = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    # The ORDER BY clause moves into the window, since it decides which row comes first
    order = expression.args.get("order")
    numbered.set(
        "order",
        order.pop() if order else exp.Order(expressions=[c.copy() for c in partition_cols]),
    )

    expression.select(exp.alias_(numbered, row_number_alias), copy=False)

    # We add aliases to the projections so that we can safely reference them in the outer query
    outer_projections = []
    reserved_names = {row_number_alias}

    # The last projection is the row number that was just added, hence the [:-1]
    for projection in expression.selects[:-1]:
        if projection.is_star:
            outer_projections = [exp.Star()]
            break

        if not isinstance(projection, exp.Alias):
            alias = find_new_name(reserved_names, projection.output_name or "_col")
            quoted = (
                projection.this.args.get("quoted") if isinstance(projection, exp.Column) else None
            )
            projection = projection.replace(exp.alias_(projection, alias, quoted=quoted))

        reserved_names.add(projection.output_name)
        outer_projections.append(projection.args["alias"])

    outer_query = exp.select(*outer_projections, copy=False)
    outer_query = outer_query.from_(expression.subquery("_t", copy=False), copy=False)
    return outer_query.where(exp.column(row_number_alias).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so they are added as
    projections to the subquery and referenced by alias in the outer filter. Columns that are
    used in the QUALIFY clause without being selected are added to the subquery as well, so
    that the outer WHERE clause can refer to them. Finally, references to projection aliases
    inside the windows of the QUALIFY clause are replaced by the aliased expressions, to
    avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The outer query wrapping the original one (as subquery `_t`), or the input unchanged
        when it is not a SELECT with a QUALIFY clause.
    """
    if not (isinstance(expression, exp.Select) and expression.args.get("qualify")):
        return expression

    # Every projection needs a name, so that the outer query can select it
    used_names = set(expression.named_selects)
    for projection in expression.selects:
        if not projection.alias_or_name:
            fresh_name = find_new_name(used_names, "_c")
            projection.replace(exp.alias_(projection, fresh_name))
            used_names.add(fresh_name)

    def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
        name = select.alias_or_name
        identifier = select.args.get("alias") or select.this
        if not isinstance(identifier, exp.Identifier):
            return name
        # Build a column, so that the quoting of the identifier is carried over
        return exp.column(name, quoted=identifier.args.get("quoted"))

    outer_selects = exp.select(*[_select_alias_or_name(s) for s in expression.selects])
    qualify_filters = expression.args["qualify"].pop().this

    expression_by_alias = {}
    for projection in expression.selects:
        if isinstance(projection, exp.Alias):
            expression_by_alias[projection.alias] = projection.this

    select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
    for candidate in list(qualify_filters.find_all(select_candidates)):
        if not isinstance(candidate, exp.Window):
            # A plain column: select it in the subquery unless it's already there
            if candidate.name not in expression.named_selects:
                expression.select(candidate.copy(), copy=False)
            continue

        if expression_by_alias:
            for column in candidate.find_all(exp.Column):
                aliased = expression_by_alias.get(column.name)
                if aliased:
                    column.replace(aliased)

        window_alias = find_new_name(expression.named_selects, "_w")
        expression.select(exp.alias_(candidate, window_alias), copy=False)
        window_ref = exp.column(window_alias)

        if isinstance(candidate.parent, exp.Qualify):
            # The window was the whole QUALIFY condition
            qualify_filters = window_ref
        else:
            candidate.replace(window_ref)

    wrapped = outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False)
    return wrapped.where(qualify_filters, copy=False)

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the precision from parameterized types in
    expressions, by dropping every `DataTypeParam` argument of every data type.

    Args:
        expression: the expression to transform; it is mutated in place.

    Returns:
        The same expression instance.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = []
        for arg in data_type.expressions:
            if not isinstance(arg, exp.DataTypeParam):
                kept_args.append(arg)

        data_type.set("expressions", kept_args)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    The aliases of the UNNESTs used as FROM or JOIN sources in the scope of a SELECT are
    collected, and the leftmost qualifier of every column is dropped when it is one of them.

    Args:
        expression: the expression to transform; SELECT statements are mutated in place.

    Returns:
        The same expression instance.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        qualifier = column.parts[0]
        # A leftmost part stored under "this" is the column name itself, not a qualifier
        if qualifier.arg_key != "this" and qualifier.this in unnest_aliases:
            qualifier.pop()

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.Expression:
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause is replaced by a table built on the matching function,
    while an UNNEST that is joined (directly or through LATERAL) is removed from the joins
    and re-added as a LATERAL VIEW. The function is POSEXPLODE when the UNNEST has an offset,
    INLINE over ARRAYS_ZIP(...) when it has several arguments, and EXPLODE otherwise.

    Args:
        expression: the expression to transform; SELECT statements are mutated in place.
        unnest_using_arrays_zip: whether an UNNEST with several arguments may be expressed
            through ARRAYS_ZIP.

    Returns:
        The same expression instance.

    Raises:
        UnsupportedError: if an UNNEST has several arguments and `unnest_using_arrays_zip` is
            False, or if a joined single-argument UNNEST doesn't have 1 or 2 column aliases.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if not has_multi_expr:
            return unnest_exprs

        if not unnest_using_arrays_zip:
            raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

        # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
        zip_exprs: t.List[exp.Expression] = [
            exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
        ]
        u.set("expressions", zip_exprs)
        return zip_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        if has_multi_expr:
            return exp.Inline
        return exp.Explode

    if not isinstance(expression, exp.Select):
        return expression

    from_ = expression.args.get("from")

    if from_ and isinstance(from_.this, exp.Unnest):
        unnest = from_.this
        alias = unnest.args.get("alias")
        exprs = unnest.expressions
        has_multi_expr = len(exprs) > 1
        first_arg, *other_args = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

        udtf = _udtf_type(unnest, has_multi_expr)(this=first_arg, expressions=other_args)
        table_alias = exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None
        unnest.replace(exp.Table(this=udtf, alias=table_alias))

    joins = expression.args.get("joins") or []

    # Iterate over a copy, because converted joins are removed from the list
    for join in list(joins):
        join_expr = join.this
        is_lateral = isinstance(join_expr, exp.Lateral)
        unnest = join_expr.this if is_lateral else join_expr

        if not isinstance(unnest, exp.Unnest):
            continue

        alias = (join_expr if is_lateral else unnest).args.get("alias")

        exprs = unnest.expressions
        # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
        has_multi_expr = len(exprs) > 1
        exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

        joins.remove(join)

        alias_cols = alias.columns if alias else []

        # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
        # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
        # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

        if not has_multi_expr and len(alias_cols) not in (1, 2):
            raise UnsupportedError(
                "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
            )

        # The column aliases only bound the number of iterations; each view gets all of them
        for e, _ in zip(exprs, alias_cols):
            lateral_view = exp.Lateral(
                this=_udtf_type(unnest, has_multi_expr)(this=e),
                view=True,
                alias=exp.TableAlias(
                    this=alias.this,  # type: ignore
                    columns=alias_cols,
                ),
            )
            expression.append("laterals", lateral_view)

    return expression

Convert cross join unnest into lateral view explode.

def explode_projection_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode projections into unnests.

    Args:
        index_offset: The start value of the generated position series. When it is not 1,
            array sizes are decremented by one before being compared against positions, and
            the end of the series is shifted by `1 - index_offset`.

    Returns:
        A transform that rewrites `exp.Select` nodes in place (other expressions are
        returned untouched) and returns the expression it was given.
    """

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names that are already in use, so generated aliases don't clash with them
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a name based on `name` via find_new_name and reserve it in `names`
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Collects one ARRAY_SIZE expression per exploded projection
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # A single position series shared by all explodes. Only its start is known here;
            # its "end" arg is filled in after the loop, once all array sizes were collected.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        # Single alias: reuse it for the exploded value
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Unaliased projection: wrap it in an (empty) alias and look the explode
                        # up again inside the wrapper.
                        # NOTE(review): presumably needed because alias_ copies its argument
                        # by default -- confirm against exp.alias_.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Builds IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg),
                        # where the arg[0] bracket is flagged with "safe" and "offset"
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate whichever aliases the projection did not provide
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    # IF(<series position> = <unnest position>, <unnested value>)
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Posexplode also yields the position: add it as an extra projection
                        # right after the value projection
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The shared series is attached only once, when the first explode is found:
                    # cross joined if the query already has a FROM, used as the FROM otherwise
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Only the local `size` is rebound here; the entry appended to `arrays`
                    # above stays the plain ARRAY_SIZE expression
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position matches the unnest position, or
                    # where the series position is beyond `size` and the unnest position
                    # equals `size`
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the greatest of the collected array sizes
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest

Convert explode/posexplode projections into unnests.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile call in a WITHIN GROUP clause.

    The first argument of the percentile becomes the ORDER BY key of the new WITHIN GROUP
    node, and the second argument becomes the percentile's only argument.

    Args:
        expression: The expression to transform.

    Returns:
        A new `exp.WithinGroup` wrapping the (mutated) percentile, or the input expression
        if it is not a percentile, is already inside a WITHIN GROUP, or has no second argument.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    # Detach both arguments before rewiring them
    ordered_value = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile> WITHIN GROUP (ORDER BY ...)` by an `exp.ApproxQuantile` call.

    Args:
        expression: The expression to transform.

    Returns:
        The `exp.ApproxQuantile` node that replaced the WITHIN GROUP expression, or the input
        expression if it is not a WITHIN GROUP over a percentile with an ORDER BY.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    # The value being aggregated is the operand of the first ORDER BY key
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(replacement)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit column list.

    CTEs whose alias already lists columns are left alone. For the others, the column names
    are taken from the projections of the CTE's query (for set operations, from its left-most
    operand); projections without a name get a generated `_c_`-prefixed one.

    Args:
        expression: The expression to transform.

    Returns:
        The same expression, modified in place when it is a recursive `exp.With`.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        source = cte.this
        if isinstance(source, exp.SetOperation):
            source = source.this

        column_names = []
        for projection in source.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Only `exp.Cast` / `exp.TryCast` nodes whose operand is named "epoch" (case-insensitively)
    and whose target type is one of `exp.DataType.TEMPORAL_TYPES` are affected; their operand
    is replaced by the string literal '1970-01-01 00:00:00'.

    Args:
        expression: The expression to transform.

    Returns:
        The same expression, possibly modified in place.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Every SEMI/ANTI join that has an ON condition is removed from the query, and a
    `[NOT] EXISTS (SELECT 1 FROM <join source> WHERE <on condition>)` predicate is added
    to the query's WHERE clause in its place. Joins without an ON condition are kept.

    Args:
        expression: The expression to transform; only `exp.Select` nodes are rewritten.

    Returns:
        The same expression, modified in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: `join.pop()` removes the join from the query's
        # join list, and mutating that list while looping over it would make the loop skip the
        # join that follows each rewritten one (e.g. the second of two consecutive SEMI joins).
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query containing exactly one FULL OUTER join as a UNION ALL of two queries.

    The first branch is the original query with the join turned into a LEFT join (its LIMIT
    and ORDER BY are dropped). The second branch is a copy of the original query with the
    join turned into a RIGHT join, filtered by a NOT EXISTS anti-join on the join condition
    and stripped of its CTEs. Queries with zero or several FULL OUTER joins are left as-is.

    Args:
        expression: The expression to transform.

    Returns:
        The union of both branches, or the input expression when no rewrite applies.
    """
    if not isinstance(expression, exp.Select):
        return expression

    candidates = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]

    if len(candidates) != 1:
        return expression

    # The copy must be taken before the original is mutated below
    right_branch = expression.copy()
    expression.set("limit", None)
    position, full_join = candidates[0]

    left_source = expression.args["from"].alias_or_name
    right_source = full_join.alias_or_name

    # Without an ON condition, derive one from the USING columns
    join_conditions = full_join.args.get("on") or exp.and_(
        *[
            exp.column(col, left_source).eq(exp.column(col, right_source))
            for col in full_join.args.get("using")
        ]
    )

    full_join.set("side", "left")
    anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)

    right_branch.args["joins"][position].set("side", "right")
    right_branch = right_branch.where(exp.Exists(this=anti_join_clause).not_())
    right_branch.args.pop("with", None)  # remove CTEs from RIGHT side
    expression.args.pop("order", None)  # remove order by from LEFT side

    return exp.union(expression, right_branch, copy=False, distinct=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: ~E) -> ~E:
def move_ctes_to_top_level(expression: E) -> E:
    """
    Hoist every nested WITH clause into the WITH clause of the given expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    CTEs that were nested inside another CTE are inserted right before that CTE; all others
    are appended at the end. The top-level WITH becomes recursive if any merged WITH was.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    outer_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        if nested_with.parent is expression:
            continue

        if not outer_with:
            # No top-level WITH yet: promote the first nested one as a whole
            outer_with = nested_with.pop()
            expression.set("with", outer_with)
            continue

        if nested_with.recursive:
            outer_with.set("recursive", True)

        # Must be looked up before the WITH is detached from the tree
        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if enclosing_cte:
            ctes = outer_with.expressions
            insert_at = ctes.index(enclosing_cte)
            ctes[insert_at:insert_at] = nested_with.expressions
            outer_with.set("expressions", ctes)
        else:
            outer_with.set("expressions", outer_with.expressions + nested_with.expressions)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's `ensure_bools` helper together with
    a callback that replaces a node `x` by `x <> 0` when `x` is a number, is typed as
    UNKNOWN or numeric (subquery predicates excluded), or is an untyped column.

    Args:
        expression: The expression to transform.

    Returns:
        The same expression, modified in place.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _compare_with_zero(node: exp.Expression) -> None:
        needs_comparison = node.is_number

        if not needs_comparison and not isinstance(node, exp.SubqueryPredicate):
            needs_comparison = node.is_type(
                exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
            )

        if not needs_comparison:
            needs_comparison = isinstance(node, exp.Column) and not node.type

        if needs_comparison:
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _compare_with_zero)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip the qualifiers from every column found in the expression.

    For each `exp.Column`, all parts except the last one (i.e. the table, db and catalog
    args) are removed, leaving only the column name.

    Args:
        expression: The expression to transform.

    Returns:
        The same expression, modified in place.
    """
    for column in expression.find_all(exp.Column):
        # Everything before the last part qualifies the column
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove the parent node of every `exp.UniqueColumnConstraint` found in a CREATE statement.

    Args:
        expression: The `exp.Create` expression to transform.

    Returns:
        The same expression, modified in place.

    Raises:
        AssertionError: If the expression is not an `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique.parent
        if owner:
            owner.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map `CREATE TEMPORARY TABLE ... AS <query>` to `CREATE TEMPORARY VIEW ... AS <query>`.

    Args:
        expression: The `exp.Create` expression to transform.
        tmp_storage_provider: Callback applied to temporary table definitions that have no
            query attached; by default it returns the expression unchanged.

    Returns:
        A new `exp.Create` of kind "TEMPORARY VIEW" for temporary CTAS statements, the result
        of `tmp_storage_provider` for other temporary tables, and the input expression otherwise.

    Raises:
        AssertionError: If the expression is not an `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move partition columns from the table schema into the PARTITIONED BY property.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement. Column names are
    matched case-insensitively.

    Args:
        expression: The `exp.Create` expression to transform.

    Returns:
        The same expression, modified in place when it creates a TABLE or VIEW with a schema.

    Raises:
        AssertionError: If the expression is not an `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    partition_names = {value.name.upper() for value in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed PARTITIONED BY column definitions into the table schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY property holds a schema made only of typed column definitions,
    those definitions are appended to the statement's schema and the property is reduced to
    a tuple of the column names.

    Args:
        expression: The `exp.Create` expression to transform.

    Returns:
        The same expression, possibly modified in place.

    Raises:
        AssertionError: If the expression is not an `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(e, exp.ColumnDef) and e.kind for e in partition_defs):
        return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(column_def.this) for column_def in partition_defs]
    )

    schema = expression.this
    for column_def in partition_defs:
        schema.append("expressions", column_def)

    prop.set("this", partition_names)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each `exp.PropertyEQ` argument of an `exp.Struct` is replaced by an alias whose value is
    the property's right-hand side and whose name is its left-hand side; other arguments
    are kept as they are.

    Args:
        expression: The expression to transform.

    Returns:
        The same expression, modified in place when it is an `exp.Struct`.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            field = exp.alias_(field.expression, field.this)
        fields.append(field)

    expression.set("expressions", fields)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.

    Raises:
        AssertionError: If a marked column is not part of a binary predicate, is unqualified,
            if marks appear on both sides of a predicate or span several tables, or if no
            table is left to be used in the new FROM clause.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Only queries that have both a WHERE clause and joins are processed
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Remember the parent before detaching the predicate from the tree
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the tables of the marked columns. Clearing them also
            # makes the outer loop skip the other marked columns of this same predicate.
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # Add predicate if join already copied, or add join if it is new
            # (`col` is the last column of the loop above; all of them share the same table,
            # as asserted, and a table that isn't a join source falls back to the FROM source)
            join_this = old_joins.get(col.table, query_from).this
            existing_join = new_joins.get(join_this.alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args["on"], join_predicate))
            else:
                new_joins[join_this.alias_or_name] = exp.Join(
                    this=join_this.copy(), on=join_predicate.copy(), kind="LEFT"
                )

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # Names of the original join sources that were not turned into LEFT joins
        only_old_join_sources = old_joins.keys() - new_joins.keys()

        # The FROM source itself became a LEFT join, so another source has to take its place
        if query_from.alias_or_name in new_joins:
            assert (
                len(only_old_join_sources) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_join_sources)[0]
            query.set("from", exp.From(this=old_joins.pop(new_from_name).this))
            only_old_join_sources.remove(new_from_name)

        if new_joins:
            # Untouched joins are kept after the new ones; those without a kind become CROSS
            only_old_join_expressions = []
            for old_join_source in only_old_join_sources:
                old_join_expression = old_joins[old_join_source]
                if not old_join_expression.kind:
                    old_join_expression.set("kind", "CROSS")

                only_old_join_expressions.append(old_join_expression)

            query.set("joins", list(new_joins.values()) + only_old_join_expressions)

        # Drop the WHERE clause if all of its predicates were moved into joins
        if not where.this:
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example, `SELECT * FROM a, b WHERE a.id = b.id(+)` is converted to `SELECT * FROM a LEFT JOIN b ON a.id = b.id`.

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.

def any_to_exists( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform the ANY operator over arrays into Spark's EXISTS higher-order function.

    The ANY operand inside its enclosing binary comparison is replaced by a lambda argument
    `x`, and the comparison becomes the body of the lambda passed to EXISTS. For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> 5 > x)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for
    this transformation, so ANY over a query is left untouched.

    Args:
        expression: The expression to transform; only `exp.Select` nodes are rewritten.

    Returns:
        The same expression, modified in place.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for any_expr in expression.find_all(exp.Any):
        operand = any_expr.this
        comparison = any_expr.parent

        if isinstance(operand, exp.Query) or not isinstance(comparison, exp.Binary):
            continue

        placeholder = exp.to_identifier("x")
        any_expr.replace(placeholder)

        predicate = exp.Lambda(this=comparison.copy(), expressions=[placeholder])
        comparison.replace(exp.Exists(this=operand.unnest(), expression=predicate))

    return expression

Transform ANY operator to Spark's EXISTS

For example, Postgres `SELECT * FROM tbl WHERE 5 > ANY(tbl.col)` becomes Spark `SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)`.

Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation

def eliminate_window_clause( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 984def eliminate_window_clause(expression: exp.Expression) -> exp.Expression:
 985    """Eliminates the `WINDOW` query clause by inling each named window."""
 986    if isinstance(expression, exp.Select) and expression.args.get("windows"):
 987        from sqlglot.optimizer.scope import find_all_in_scope
 988
 989        windows = expression.args["windows"]
 990        expression.set("windows", None)
 991
 992        window_expression: t.Dict[str, exp.Expression] = {}
 993
 994        def _inline_inherited_window(window: exp.Expression) -> None:
 995            inherited_window = window_expression.get(window.alias.lower())
 996            if not inherited_window:
 997                return
 998
 999            window.set("alias", None)
1000            for key in ("partition_by", "order", "spec"):
1001                arg = inherited_window.args.get(key)
1002                if arg:
1003                    window.set(key, arg.copy())
1004
1005        for window in windows:
1006            _inline_inherited_window(window)
1007            window_expression[window.name.lower()] = window
1008
1009        for window in find_all_in_scope(expression, exp.Window):
1010            _inline_inherited_window(window)
1011
1012    return expression

Eliminates the WINDOW query clause by inlining each named window.