Edit on GitHub

sqlglot.transforms

   1from __future__ import annotations
   2
   3import typing as t
   4
   5from sqlglot import expressions as exp
   6from sqlglot.errors import UnsupportedError
   7from sqlglot.helper import find_new_name, name_sequence
   8
   9
  10if t.TYPE_CHECKING:
  11    from sqlglot._typing import E
  12    from sqlglot.generator import Generator
  13
  14
  15def preprocess(
  16    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
  17) -> t.Callable[[Generator, exp.Expression], str]:
  18    """
  19    Creates a new transform by chaining a sequence of transformations and converts the resulting
  20    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
  21    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
  22
  23    Args:
  24        transforms: sequence of transform functions. These will be called in order.
  25
  26    Returns:
  27        Function that can be used as a generator transform.
  28    """
  29
  30    def _to_sql(self, expression: exp.Expression) -> str:
  31        expression_type = type(expression)
  32
  33        try:
  34            expression = transforms[0](expression)
  35            for transform in transforms[1:]:
  36                expression = transform(expression)
  37        except UnsupportedError as unsupported_error:
  38            self.unsupported(str(unsupported_error))
  39
  40        _sql_handler = getattr(self, expression.key + "_sql", None)
  41        if _sql_handler:
  42            return _sql_handler(expression)
  43
  44        transforms_handler = self.TRANSFORMS.get(type(expression))
  45        if transforms_handler:
  46            if expression_type is type(expression):
  47                if isinstance(expression, exp.Func):
  48                    return self.function_fallback_sql(expression)
  49
  50                # Ensures we don't enter an infinite loop. This can happen when the original expression
  51                # has the same type as the final expression and there's no _sql method available for it,
  52                # because then it'd re-enter _to_sql.
  53                raise ValueError(
  54                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
  55                )
  56
  57            return transforms_handler(self, expression)
  58
  59        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
  60
  61    return _to_sql
  62
  63
  64def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
  65    if isinstance(expression, exp.Select):
  66        count = 0
  67        recursive_ctes = []
  68
  69        for unnest in expression.find_all(exp.Unnest):
  70            if (
  71                not isinstance(unnest.parent, (exp.From, exp.Join))
  72                or len(unnest.expressions) != 1
  73                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
  74            ):
  75                continue
  76
  77            generate_date_array = unnest.expressions[0]
  78            start = generate_date_array.args.get("start")
  79            end = generate_date_array.args.get("end")
  80            step = generate_date_array.args.get("step")
  81
  82            if not start or not end or not isinstance(step, exp.Interval):
  83                continue
  84
  85            alias = unnest.args.get("alias")
  86            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
  87
  88            start = exp.cast(start, "date")
  89            date_add = exp.func(
  90                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
  91            )
  92            cast_date_add = exp.cast(date_add, "date")
  93
  94            cte_name = "_generated_dates" + (f"_{count}" if count else "")
  95
  96            base_query = exp.select(start.as_(column_name))
  97            recursive_query = (
  98                exp.select(cast_date_add)
  99                .from_(cte_name)
 100                .where(cast_date_add <= exp.cast(end, "date"))
 101            )
 102            cte_query = base_query.union(recursive_query, distinct=False)
 103
 104            generate_dates_query = exp.select(column_name).from_(cte_name)
 105            unnest.replace(generate_dates_query.subquery(cte_name))
 106
 107            recursive_ctes.append(
 108                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
 109            )
 110            count += 1
 111
 112        if recursive_ctes:
 113            with_expression = expression.args.get("with") or exp.With()
 114            with_expression.set("recursive", True)
 115            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
 116            expression.set("with", with_expression)
 117
 118    return expression
 119
 120
 121def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
 122    """Unnests GENERATE_SERIES or SEQUENCE table references."""
 123    this = expression.this
 124    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
 125        unnest = exp.Unnest(expressions=[this])
 126        if expression.alias:
 127            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
 128
 129        return unnest
 130
 131    return expression
 132
 133
 134def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 135    """
 136    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 137
 138    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 139
 140    Args:
 141        expression: the expression that will be transformed.
 142
 143    Returns:
 144        The transformed expression.
 145    """
 146    if (
 147        isinstance(expression, exp.Select)
 148        and expression.args.get("distinct")
 149        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
 150    ):
 151        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")
 152
 153        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 154        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 155
 156        order = expression.args.get("order")
 157        if order:
 158            window.set("order", order.pop())
 159        else:
 160            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 161
 162        window = exp.alias_(window, row_number_window_alias)
 163        expression.select(window, copy=False)
 164
 165        # We add aliases to the projections so that we can safely reference them in the outer query
 166        new_selects = []
 167        taken_names = {row_number_window_alias}
 168        for select in expression.selects[:-1]:
 169            if select.is_star:
 170                new_selects = [exp.Star()]
 171                break
 172
 173            if not isinstance(select, exp.Alias):
 174                alias = find_new_name(taken_names, select.output_name or "_col")
 175                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
 176                select = select.replace(exp.alias_(select, alias, quoted=quoted))
 177
 178            taken_names.add(select.output_name)
 179            new_selects.append(select.args["alias"])
 180
 181        return (
 182            exp.select(*new_selects, copy=False)
 183            .from_(expression.subquery("_t", copy=False), copy=False)
 184            .where(exp.column(row_number_window_alias).eq(1), copy=False)
 185        )
 186
 187    return expression
 188
 189
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression to transform; only `exp.Select` nodes that have a QUALIFY
            clause are rewritten, anything else is returned as is.

    Returns:
        A new outer SELECT that reads from the original query (wrapped as subquery `_t`) and
        filters on the former QUALIFY condition, or the input expression if nothing applied.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c" alias so the outer query can reference it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Builds the outer-query reference for an inner projection. When the alias (or the
            # projection's `this`) is an identifier, its "quoted" flag is carried over.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are computed before any helper projection is added below, so
        # they only expose what the original query selected
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach QUALIFY from the inner query; its condition becomes the outer WHERE
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # Columns are not considered when the inner query selects a star
        # NOTE(review): presumably because the star already exposes them -- confirm
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        # The matches are materialized first because the loop replaces nodes in qualify_filters
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline aliased projections referenced inside the window, since the window
                    # is about to become a projection of the same (inner) query
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the inner query under a fresh "_w" alias and refer to
                # it by that alias in the filter
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the whole QUALIFY condition, so the filter itself is rebound
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A column used by the filter but not selected: add it to the inner query
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
 252
 253
 254def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
 255    """
 256    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
 257    other expressions. This transforms removes the precision from parameterized types in expressions.
 258    """
 259    for node in expression.find_all(exp.DataType):
 260        node.set(
 261            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
 262        )
 263
 264    return expression
 265
 266
 267def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
 268    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
 269    from sqlglot.optimizer.scope import find_all_in_scope
 270
 271    if isinstance(expression, exp.Select):
 272        unnest_aliases = {
 273            unnest.alias
 274            for unnest in find_all_in_scope(expression, exp.Unnest)
 275            if isinstance(unnest.parent, (exp.From, exp.Join))
 276        }
 277        if unnest_aliases:
 278            for column in expression.find_all(exp.Column):
 279                leftmost_part = column.parts[0]
 280                if leftmost_part.arg_key != "this" and leftmost_part.this in unnest_aliases:
 281                    leftmost_part.pop()
 282
 283    return expression
 284
 285
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Two places are rewritten in a SELECT: an UNNEST used as the FROM source is replaced by a
    table wrapping the corresponding explode function, and every joined UNNEST (optionally
    wrapped in a LATERAL) is removed from the joins and re-added as a LATERAL VIEW.

    Args:
        expression: the expression to transform; only `exp.Select` nodes are rewritten.
        unnest_using_arrays_zip: whether an UNNEST with multiple input arrays may be rewritten
            as INLINE(ARRAYS_ZIP(...)). When False, such an UNNEST raises an error instead.

    Returns:
        The transformed expression (modified in place).

    Raises:
        UnsupportedError: if an UNNEST has multiple input arrays and `unnest_using_arrays_zip`
            is False, or if a joined single-array UNNEST doesn't have 1 or 2 column aliases.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Returns the expressions to explode: the inputs unchanged for a single array, or a
        # single ARRAYS_ZIP(...) call (also stored back on `u`) for multiple arrays.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # An offset takes precedence: POSEXPLODE is used whenever the UNNEST has one
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        # Case 1: the UNNEST is the FROM source itself
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            columns = alias.columns if alias else []
            offset = unnest.args.get("offset")
            if offset:
                # The position column comes first; "pos" is used when the offset isn't named
                columns.insert(
                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
                )

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
                )
            )

        # Case 2: UNNESTs that appear as join sources
        joins = expression.args.get("joins") or []
        # Iterate over a copy, since matching joins are removed from `joins` inside the loop
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For a LATERAL, the alias is read from the LATERAL node rather than the UNNEST
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                offset = unnest.args.get("offset")
                if offset:
                    # The position column comes first; "pos" is used when the offset isn't named
                    alias_cols.insert(
                        0,
                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
                    )

                # One LATERAL VIEW is appended per (expression, alias column) pair produced by
                # zip; `column` itself is not used, every view gets the full alias column list
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
 391
 392
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode projections into unnests.

    Each projection containing an explode is replaced by a conditional column that reads from
    a cross-joined UNNEST of the exploded array, and a generated position series (shared by
    all explodes of the query) is joined in to line the unnested rows up by position.

    Args:
        index_offset: the start value of the generated position series; it is also used to
            adjust the array sizes the positions are compared against.

    Returns:
        A transform function that rewrites `exp.Select` nodes in place and returns any other
        expression unchanged.
    """

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a name that isn't in `names` yet and reserves it right away
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Sizes of the exploded arrays, used to compute the end of the position series
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # The position series; only its start is known here, its end is set once all
            # explodes have been processed (see the bottom of this function)
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into a single alias node (`alias`), keeping
                    # track of any user-provided output names
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # The projection was copied by alias_, so the explode must be looked
                        # up again inside the new node
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Builds IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg),
                        # where the arg[0] access is flagged as "safe" and "offset"
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate output names for whatever the user didn't name explicitly
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode becomes IF(<series pos> = <unnest pos>, <unnested value>)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also outputs the position, so an extra projection for it
                        # is inserted right after the value projection
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is attached only once, when the first explode is seen
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Note that `arrays` keeps the unadjusted size; only the filter below
                    # uses the adjusted one
                    if index_offset != 1:
                        size = size - 1

                    # Keep the rows where both positions match, plus the rows where the series
                    # position is beyond `size` and the unnest position equals `size`
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the largest exploded array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest
 544
 545
 546def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
 547    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
 548    if (
 549        isinstance(expression, exp.PERCENTILES)
 550        and not isinstance(expression.parent, exp.WithinGroup)
 551        and expression.expression
 552    ):
 553        column = expression.this.pop()
 554        expression.set("this", expression.expression.pop())
 555        order = exp.Order(expressions=[exp.Ordered(this=column)])
 556        expression = exp.WithinGroup(this=expression, expression=order)
 557
 558    return expression
 559
 560
 561def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
 562    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
 563    if (
 564        isinstance(expression, exp.WithinGroup)
 565        and isinstance(expression.this, exp.PERCENTILES)
 566        and isinstance(expression.expression, exp.Order)
 567    ):
 568        quantile = expression.this.this
 569        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
 570        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
 571
 572    return expression
 573
 574
 575def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
 576    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
 577    if isinstance(expression, exp.With) and expression.recursive:
 578        next_name = name_sequence("_c_")
 579
 580        for cte in expression.expressions:
 581            if not cte.args["alias"].columns:
 582                query = cte.this
 583                if isinstance(query, exp.SetOperation):
 584                    query = query.this
 585
 586                cte.args["alias"].set(
 587                    "columns",
 588                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
 589                )
 590
 591    return expression
 592
 593
 594def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
 595    """Replace 'epoch' in casts by the equivalent date literal."""
 596    if (
 597        isinstance(expression, (exp.Cast, exp.TryCast))
 598        and expression.name.lower() == "epoch"
 599        and expression.to.this in exp.DataType.TEMPORAL_TYPES
 600    ):
 601        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
 602
 603    return expression
 604
 605
 606def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
 607    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
 608    if isinstance(expression, exp.Select):
 609        for join in expression.args.get("joins") or []:
 610            on = join.args.get("on")
 611            if on and join.kind in ("SEMI", "ANTI"):
 612                subquery = exp.select("1").from_(join.this).where(on)
 613                exists = exp.Exists(this=subquery)
 614                if join.kind == "ANTI":
 615                    exists = exists.not_(copy=False)
 616
 617                join.pop()
 618                expression.where(exists, copy=False)
 619
 620    return expression
 621
 622
 623def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
 624    """
 625    Converts a query with a FULL OUTER join to a union of identical queries that
 626    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
 627    for queries that have a single FULL OUTER join.
 628    """
 629    if isinstance(expression, exp.Select):
 630        full_outer_joins = [
 631            (index, join)
 632            for index, join in enumerate(expression.args.get("joins") or [])
 633            if join.side == "FULL"
 634        ]
 635
 636        if len(full_outer_joins) == 1:
 637            expression_copy = expression.copy()
 638            expression.set("limit", None)
 639            index, full_outer_join = full_outer_joins[0]
 640
 641            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
 642            join_conditions = full_outer_join.args.get("on") or exp.and_(
 643                *[
 644                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
 645                    for col in full_outer_join.args.get("using")
 646                ]
 647            )
 648
 649            full_outer_join.set("side", "left")
 650            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
 651            expression_copy.args["joins"][index].set("side", "right")
 652            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
 653            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
 654            expression.args.pop("order", None)  # remove order by from LEFT side
 655
 656            return exp.union(expression, expression_copy, copy=False, distinct=False)
 657
 658    return expression
 659
 660
 661def move_ctes_to_top_level(expression: E) -> E:
 662    """
 663    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
 664    defined at the top-level, so for example queries like:
 665
 666        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
 667
 668    are invalid in those dialects. This transformation can be used to ensure all CTEs are
 669    moved to the top level so that the final SQL code is valid from a syntax standpoint.
 670
 671    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
 672    """
 673    top_level_with = expression.args.get("with")
 674    for inner_with in expression.find_all(exp.With):
 675        if inner_with.parent is expression:
 676            continue
 677
 678        if not top_level_with:
 679            top_level_with = inner_with.pop()
 680            expression.set("with", top_level_with)
 681        else:
 682            if inner_with.recursive:
 683                top_level_with.set("recursive", True)
 684
 685            parent_cte = inner_with.find_ancestor(exp.CTE)
 686            inner_with.pop()
 687
 688            if parent_cte:
 689                i = top_level_with.expressions.index(parent_cte)
 690                top_level_with.expressions[i:i] = inner_with.expressions
 691                top_level_with.set("expressions", top_level_with.expressions)
 692            else:
 693                top_level_with.set(
 694                    "expressions", top_level_with.expressions + inner_with.expressions
 695                )
 696
 697    return expression
 698
 699
 700def ensure_bools(expression: exp.Expression) -> exp.Expression:
 701    """Converts numeric values used in conditions into explicit boolean expressions."""
 702    from sqlglot.optimizer.canonicalize import ensure_bools
 703
 704    def _ensure_bool(node: exp.Expression) -> None:
 705        if (
 706            node.is_number
 707            or (
 708                not isinstance(node, exp.SubqueryPredicate)
 709                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
 710            )
 711            or (isinstance(node, exp.Column) and not node.type)
 712        ):
 713            node.replace(node.neq(0))
 714
 715    for node in expression.walk():
 716        ensure_bools(node, _ensure_bool)
 717
 718    return expression
 719
 720
 721def unqualify_columns(expression: exp.Expression) -> exp.Expression:
 722    for column in expression.find_all(exp.Column):
 723        # We only wanna pop off the table, db, catalog args
 724        for part in column.parts[:-1]:
 725            part.pop()
 726
 727    return expression
 728
 729
 730def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
 731    assert isinstance(expression, exp.Create)
 732    for constraint in expression.find_all(exp.UniqueColumnConstraint):
 733        if constraint.parent:
 734            constraint.parent.pop()
 735
 736    return expression
 737
 738
 739def ctas_with_tmp_tables_to_create_tmp_view(
 740    expression: exp.Expression,
 741    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
 742) -> exp.Expression:
 743    assert isinstance(expression, exp.Create)
 744    properties = expression.args.get("properties")
 745    temporary = any(
 746        isinstance(prop, exp.TemporaryProperty)
 747        for prop in (properties.expressions if properties else [])
 748    )
 749
 750    # CTAS with temp tables map to CREATE TEMPORARY VIEW
 751    if expression.kind == "TABLE" and temporary:
 752        if expression.expression:
 753            return exp.Create(
 754                kind="TEMPORARY VIEW",
 755                this=expression.this,
 756                expression=expression.expression,
 757            )
 758        return tmp_storage_provider(expression)
 759
 760    return expression
 761
 762
 763def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
 764    """
 765    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
 766    PARTITIONED BY value is an array of column names, they are transformed into a schema.
 767    The corresponding columns are removed from the create statement.
 768    """
 769    assert isinstance(expression, exp.Create)
 770    has_schema = isinstance(expression.this, exp.Schema)
 771    is_partitionable = expression.kind in {"TABLE", "VIEW"}
 772
 773    if has_schema and is_partitionable:
 774        prop = expression.find(exp.PartitionedByProperty)
 775        if prop and prop.this and not isinstance(prop.this, exp.Schema):
 776            schema = expression.this
 777            columns = {v.name.upper() for v in prop.this.expressions}
 778            partitions = [col for col in schema.expressions if col.name.upper() in columns]
 779            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
 780            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
 781            expression.set("this", schema)
 782
 783    return expression
 784
 785
 786def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
 787    """
 788    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
 789
 790    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
 791    """
 792    assert isinstance(expression, exp.Create)
 793    prop = expression.find(exp.PartitionedByProperty)
 794    if (
 795        prop
 796        and prop.this
 797        and isinstance(prop.this, exp.Schema)
 798        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
 799    ):
 800        prop_this = exp.Tuple(
 801            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
 802        )
 803        schema = expression.this
 804        for e in prop.this.expressions:
 805            schema.append("expressions", e)
 806        prop.set("this", prop_this)
 807
 808    return expression
 809
 810
 811def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
 812    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
 813    if isinstance(expression, exp.Struct):
 814        expression.set(
 815            "expressions",
 816            [
 817                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
 818                for e in expression.expressions
 819            ],
 820        )
 821
 822    return expression
 823
 824
 825def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
 826    """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178
 827
 828    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.
 829
 830    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.
 831
 832    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.
 833
 834    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.
 835
 836    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.
 837
 838    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.
 839
 840    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.
 841
 842    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.
 843
 844    -- example with WHERE
 845    SELECT d.department_name, sum(e.salary) as total_salary
 846    FROM departments d, employees e
 847    WHERE e.department_id(+) = d.department_id
 848    group by department_name
 849
 850    -- example of left correlation in select
 851    SELECT d.department_name, (
 852        SELECT SUM(e.salary)
 853            FROM employees e
 854            WHERE e.department_id(+) = d.department_id) AS total_salary
 855    FROM departments d;
 856
 857    -- example of left correlation in from
 858    SELECT d.department_name, t.total_salary
 859    FROM departments d, (
 860            SELECT SUM(e.salary) AS total_salary
 861            FROM employees e
 862            WHERE e.department_id(+) = d.department_id
 863        ) t
 864    """
 865
 866    from sqlglot.optimizer.scope import traverse_scope
 867    from sqlglot.optimizer.normalize import normalize, normalized
 868    from collections import defaultdict
 869
 870    # we go in reverse to check the main query for left correlation
 871    for scope in reversed(traverse_scope(expression)):
 872        query = scope.expression
 873
 874        where = query.args.get("where")
 875        joins = query.args.get("joins", [])
 876
 877        # knockout: we do not support left correlation (see point 2)
 878        assert not scope.is_correlated_subquery, "Correlated queries are not supported"
 879
 880        # nothing to do - we check it here after knockout above
 881        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
 882            continue
 883
 884        # make sure we have AND of ORs to have clear join terms
 885        where = normalize(where.this)
 886        assert normalized(where), "Cannot normalize JOIN predicates"
 887
 888        joins_ons = defaultdict(list)  # dict of {name: list of join AND conditions}
 889        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
 890            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]
 891
 892            left_join_table = set(col.table for col in join_cols)
 893            if not left_join_table:
 894                continue
 895
 896            assert not (
 897                len(left_join_table) > 1
 898            ), "Cannot combine JOIN predicates from different tables"
 899
 900            for col in join_cols:
 901                col.set("join_mark", False)
 902
 903            joins_ons[left_join_table.pop()].append(cond)
 904
 905        old_joins = {join.alias_or_name: join for join in joins}
 906        new_joins = {}
 907        query_from = query.args["from"]
 908
 909        for table, predicates in joins_ons.items():
 910            join_what = old_joins.get(table, query_from).this.copy()
 911            new_joins[join_what.alias_or_name] = exp.Join(
 912                this=join_what, on=exp.and_(*predicates), kind="LEFT"
 913            )
 914
 915            for p in predicates:
 916                while isinstance(p.parent, exp.Paren):
 917                    p.parent.replace(p)
 918
 919                parent = p.parent
 920                p.pop()
 921                if isinstance(parent, exp.Binary):
 922                    parent.replace(parent.right if parent.left is None else parent.left)
 923                elif isinstance(parent, exp.Where):
 924                    parent.pop()
 925
 926        if query_from.alias_or_name in new_joins:
 927            only_old_joins = old_joins.keys() - new_joins.keys()
 928            assert (
 929                len(only_old_joins) >= 1
 930            ), "Cannot determine which table to use in the new FROM clause"
 931
 932            new_from_name = list(only_old_joins)[0]
 933            query.set("from", exp.From(this=old_joins[new_from_name].this))
 934
 935        if new_joins:
 936            for n, j in old_joins.items():  # preserve any other joins
 937                if n not in new_joins and n != query.args["from"].name:
 938                    if not j.kind:
 939                        j.set("kind", "CROSS")
 940                    new_joins[n] = j
 941            query.set("joins", list(new_joins.values()))
 942
 943    return expression
 944
 945
 946def any_to_exists(expression: exp.Expression) -> exp.Expression:
 947    """
 948    Transform ANY operator to Spark's EXISTS
 949
 950    For example,
 951        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
 952        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
 953
 954    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
 955    transformation
 956    """
 957    if isinstance(expression, exp.Select):
 958        for any_expr in expression.find_all(exp.Any):
 959            this = any_expr.this
 960            if isinstance(this, exp.Query) or isinstance(any_expr.parent, (exp.Like, exp.ILike)):
 961                continue
 962
 963            binop = any_expr.parent
 964            if isinstance(binop, exp.Binary):
 965                lambda_arg = exp.to_identifier("x")
 966                any_expr.replace(lambda_arg)
 967                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
 968                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))
 969
 970    return expression
 971
 972
 973def eliminate_window_clause(expression: exp.Expression) -> exp.Expression:
 974    """Eliminates the `WINDOW` query clause by inling each named window."""
 975    if isinstance(expression, exp.Select) and expression.args.get("windows"):
 976        from sqlglot.optimizer.scope import find_all_in_scope
 977
 978        windows = expression.args["windows"]
 979        expression.set("windows", None)
 980
 981        window_expression: t.Dict[str, exp.Expression] = {}
 982
 983        def _inline_inherited_window(window: exp.Expression) -> None:
 984            inherited_window = window_expression.get(window.alias.lower())
 985            if not inherited_window:
 986                return
 987
 988            window.set("alias", None)
 989            for key in ("partition_by", "order", "spec"):
 990                arg = inherited_window.args.get(key)
 991                if arg:
 992                    window.set(key, arg.copy())
 993
 994        for window in windows:
 995            _inline_inherited_window(window)
 996            window_expression[window.name.lower()] = window
 997
 998        for window in find_all_in_scope(expression, exp.Window):
 999            _inline_inherited_window(window)
1000
1001    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed and leaves the expression untouched.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remembered so we can tell whether the transforms changed the expression's type.
        expression_type = type(expression)

        try:
            for transform in transforms:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            # Report the problem and generate SQL for whatever expression we reached so far.
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 65def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 66    if isinstance(expression, exp.Select):
 67        count = 0
 68        recursive_ctes = []
 69
 70        for unnest in expression.find_all(exp.Unnest):
 71            if (
 72                not isinstance(unnest.parent, (exp.From, exp.Join))
 73                or len(unnest.expressions) != 1
 74                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 75            ):
 76                continue
 77
 78            generate_date_array = unnest.expressions[0]
 79            start = generate_date_array.args.get("start")
 80            end = generate_date_array.args.get("end")
 81            step = generate_date_array.args.get("step")
 82
 83            if not start or not end or not isinstance(step, exp.Interval):
 84                continue
 85
 86            alias = unnest.args.get("alias")
 87            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 88
 89            start = exp.cast(start, "date")
 90            date_add = exp.func(
 91                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 92            )
 93            cast_date_add = exp.cast(date_add, "date")
 94
 95            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 96
 97            base_query = exp.select(start.as_(column_name))
 98            recursive_query = (
 99                exp.select(cast_date_add)
100                .from_(cte_name)
101                .where(cast_date_add <= exp.cast(end, "date"))
102            )
103            cte_query = base_query.union(recursive_query, distinct=False)
104
105            generate_dates_query = exp.select(column_name).from_(cte_name)
106            unnest.replace(generate_dates_query.subquery(cte_name))
107
108            recursive_ctes.append(
109                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
110            )
111            count += 1
112
113        if recursive_ctes:
114            with_expression = expression.args.get("with") or exp.With()
115            with_expression.set("recursive", True)
116            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
117            expression.set("with", with_expression)
118
119    return expression
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A `Table` whose content is a `GenerateSeries` is returned as an `Unnest` of that series.
    When the table has an alias, the unnest is aliased as `_u` and the original alias
    becomes its single column name. Any other expression is returned unchanged.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])

    table_alias = expression.alias
    if not table_alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The original query gains a `ROW_NUMBER()` projection partitioned by the DISTINCT ON
    columns (ordered by the query's ORDER BY, or by those columns when there is none) and is
    wrapped in a subquery `_t` that the outer query filters on row number 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct or not isinstance(distinct.args.get("on"), exp.Tuple):
        return expression

    row_number_alias = find_new_name(expression.named_selects, "_row_number")

    partition_columns = distinct.pop().args["on"].expressions
    row_number = exp.Window(this=exp.RowNumber(), partition_by=partition_columns)

    # Move the query's ORDER BY into the window, or fall back to the DISTINCT ON columns.
    order = expression.args.get("order")
    if order:
        row_number.set("order", order.pop())
    else:
        row_number.set(
            "order", exp.Order(expressions=[column.copy() for column in partition_columns])
        )

    row_number = exp.alias_(row_number, row_number_alias)
    expression.select(row_number, copy=False)

    # We add aliases to the projections so that we can safely reference them in the outer query
    outer_projections = []
    used_names = {row_number_alias}
    # The last projection is the row number we just appended, so it is skipped.
    for projection in expression.selects[:-1]:
        if projection.is_star:
            outer_projections = [exp.Star()]
            break

        if not isinstance(projection, exp.Alias):
            new_alias = find_new_name(used_names, projection.output_name or "_col")
            if isinstance(projection, exp.Column):
                quoted = projection.this.args.get("quoted")
            else:
                quoted = None
            projection = projection.replace(exp.alias_(projection, new_alias, quoted=quoted))

        used_names.add(projection.output_name)
        outer_projections.append(projection.args["alias"])

    inner_query = expression.subquery("_t", copy=False)
    return (
        exp.select(*outer_projections, copy=False)
        .from_(inner_query, copy=False)
        .where(exp.column(row_number_alias).eq(1), copy=False)
    )

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if not (isinstance(expression, exp.Select) and expression.args.get("qualify")):
        return expression

    # Give every unnamed projection an alias so the outer query can refer to it.
    used_names = set(expression.named_selects)
    for projection in expression.selects:
        if not projection.alias_or_name:
            generated = find_new_name(used_names, "_c")
            projection.replace(exp.alias_(projection, generated))
            used_names.add(generated)

    def _outer_reference(projection: exp.Expression) -> str | exp.Column:
        name = projection.alias_or_name
        identifier = projection.args.get("alias") or projection.this
        if isinstance(identifier, exp.Identifier):
            # Carry over the quoting of the identifier we are referring to.
            return exp.column(name, quoted=identifier.args.get("quoted"))
        return name

    outer_query = exp.select(*[_outer_reference(p) for p in expression.selects])
    qualify_condition = expression.args["qualify"].pop().this
    aliased_expressions = {
        projection.alias: projection.this
        for projection in expression.selects
        if isinstance(projection, exp.Alias)
    }

    # With a star projection, only windows need to be added to the subquery.
    candidate_types = exp.Window if expression.is_star else (exp.Window, exp.Column)
    for candidate in list(qualify_condition.find_all(candidate_types)):
        if isinstance(candidate, exp.Window):
            if aliased_expressions:
                # Replace references to projection aliases with the aliased expressions.
                for referenced in candidate.find_all(exp.Column):
                    aliased = aliased_expressions.get(referenced.name)
                    if aliased:
                        referenced.replace(aliased)

            window_alias = find_new_name(expression.named_selects, "_w")
            expression.select(exp.alias_(candidate, window_alias), copy=False)
            window_reference = exp.column(window_alias)

            if isinstance(candidate.parent, exp.Qualify):
                # The window was the whole QUALIFY condition.
                qualify_condition = window_reference
            else:
                candidate.replace(window_reference)
        elif candidate.name not in expression.named_selects:
            expression.select(candidate.copy(), copy=False)

    inner_query = expression.subquery(alias="_t", copy=False)
    return outer_query.from_(inner_query, copy=False).where(qualify_condition, copy=False)

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Args:
        expression: the expression to rewrite (mutated in place).

    Returns:
        The same `expression` object.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_arguments = [
            argument
            for argument in data_type.expressions
            if not isinstance(argument, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept_arguments)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Only aliases of unnests that sit directly in a FROM or JOIN of this SELECT's scope are
    considered; a column loses its leftmost qualifier when that qualifier is one of them.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = {
        unnest.alias
        for unnest in find_all_in_scope(expression, exp.Unnest)
        if isinstance(unnest.parent, (exp.From, exp.Join))
    }
    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        qualifier = column.parts[0]
        # arg_key "this" means the column is unqualified, so there is nothing to strip.
        if qualifier.arg_key != "this" and qualifier.this in unnest_aliases:
            qualifier.pop()

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.Expression:
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause becomes a table built from EXPLODE / POSEXPLODE / INLINE,
    and each joined UNNEST (optionally wrapped in LATERAL) is removed from the joins and
    re-added as a LATERAL VIEW.

    Args:
        expression: the expression to rewrite (mutated in place when it is a SELECT).
        unnest_using_arrays_zip: whether an UNNEST over several arrays may be expressed as
            INLINE(ARRAYS_ZIP(...)).

    Returns:
        The same `expression` object.

    Raises:
        UnsupportedError: if an UNNEST has several arrays and `unnest_using_arrays_zip` is
            false, or if a joined single-array UNNEST does not have one or two column aliases.
    """

    def _zip_if_multiple(
        u: exp.Unnest, arrays: t.List[exp.Expression], is_multi: bool
    ) -> t.List[exp.Expression]:
        if not is_multi:
            return arrays

        if not unnest_using_arrays_zip:
            raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

        # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
        zipped: t.List[exp.Expression] = [
            exp.Anonymous(this="ARRAYS_ZIP", expressions=arrays)
        ]
        u.set("expressions", zipped)
        return zipped

    def _explode_class(u: exp.Unnest, is_multi: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        if is_multi:
            return exp.Inline
        return exp.Explode

    def _position_identifier(offset: t.Any) -> exp.Expression:
        # An explicit offset identifier is reused; otherwise the position column is "pos".
        return offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")

    if not isinstance(expression, exp.Select):
        return expression

    from_ = expression.args.get("from")

    if from_ and isinstance(from_.this, exp.Unnest):
        from_unnest = from_.this
        from_alias = from_unnest.args.get("alias")
        from_arrays = from_unnest.expressions
        from_is_multi = len(from_arrays) > 1
        first, *rest = _zip_if_multiple(from_unnest, from_arrays, from_is_multi)

        from_columns = from_alias.columns if from_alias else []
        from_offset = from_unnest.args.get("offset")
        if from_offset:
            from_columns.insert(0, _position_identifier(from_offset))

        udtf = _explode_class(from_unnest, from_is_multi)(this=first, expressions=rest)
        table_alias = (
            exp.TableAlias(this=from_alias.this, columns=from_columns) if from_alias else None
        )
        from_unnest.replace(exp.Table(this=udtf, alias=table_alias))

    joins = expression.args.get("joins") or []
    # Iterate over a snapshot because matching joins are removed from the list.
    for join in list(joins):
        joined = join.this
        is_lateral = isinstance(joined, exp.Lateral)
        unnest = joined.this if is_lateral else joined

        if not isinstance(unnest, exp.Unnest):
            continue

        alias = joined.args.get("alias") if is_lateral else unnest.args.get("alias")

        arrays = unnest.expressions
        # The number of unnest.expressions will be changed by the zip helper, we need to record it here
        is_multi = len(arrays) > 1
        arrays = _zip_if_multiple(unnest, arrays, is_multi)

        joins.remove(join)

        alias_cols = alias.columns if alias else []

        # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
        # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
        # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html
        if not is_multi and len(alias_cols) not in (1, 2):
            raise UnsupportedError(
                "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
            )

        offset = unnest.args.get("offset")
        if offset:
            alias_cols.insert(0, _position_identifier(offset))

        # Zipping with alias_cols only bounds the number of laterals; every lateral
        # receives the full list of column aliases.
        for exploded, _ in zip(arrays, alias_cols):
            expression.append(
                "laterals",
                exp.Lateral(
                    this=_explode_class(unnest, is_multi)(this=exploded),
                    view=True,
                    alias=exp.TableAlias(
                        this=alias.this,  # type: ignore
                        columns=alias_cols,
                    ),
                ),
            )

    return expression

Convert cross join unnest into lateral view explode.

def explode_projection_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Build a transform that converts explode/posexplode projections into unnests.

    Args:
        index_offset: start value of the generated position series. It is also used to adjust
            the array-size comparisons and the end of the series.

    Returns:
        A function that rewrites `exp.Select` expressions in place and returns them. Any other
        expression is returned unchanged.
    """

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names that are already in use; generated aliases are registered in these sets too
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Obtain a name through find_new_name and record it in `names`
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per exploded projection, used for the series end below
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")

            # UNNEST(GENERATE_SERIES(index_offset, ...)) source; its end is only set at the
            # bottom of this function, once every exploded array has been collected
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Make sure the projection is an Alias node, keeping any user-provided names
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # exp.alias_ is called without copy=False here, so the explode node
                        # is looked up again inside the newly created alias
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg), where
                        # the arg[0] access is flagged as "safe" -- presumably so that empty or
                        # NULL arrays still yield a row
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate whichever aliases the projection did not provide
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call becomes IF(<series pos> = <unnest pos>, <unnest value>)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also outputs the position, as an extra projection that is
                        # inserted right after the value projection
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode: add the series source (CROSS JOIN if a FROM exists)
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Only the WHERE comparison below uses the adjusted size; the expression
                    # stored in `arrays` is the unadjusted ARRAY_SIZE
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position equals the unnest position, or where
                    # the series position is past `size` and the unnest position equals `size`
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the greatest array size, adjusted for index_offset
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest

Convert explode/posexplode projections into unnests.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wraps two-argument percentiles in a WITHIN GROUP clause.

    The percentile's first argument is moved into the ORDER BY of the new WITHIN GROUP node
    and its second argument becomes the percentile's only argument. Percentiles whose parent
    is already a WITHIN GROUP, or that have no second argument, are returned untouched.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    remaining_arg = expression.expression.pop()
    expression.set("this", remaining_arg)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replaces a percentile's WITHIN GROUP clause with an `ApproxQuantile` call.

    Only WITHIN GROUP nodes that wrap a percentile and carry an ORDER BY are rewritten: the
    first ordered term becomes the input value and the percentile's argument the quantile.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not (
        isinstance(percentile, exp.PERCENTILES) and isinstance(expression.expression, exp.Order)
    ):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(replacement)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Defines the columns of recursive CTEs from the output names of their projections.

    CTEs that already list their columns are skipped. For set operations, the projections of
    the left-hand query are used. Projections without an output name get a generated
    `_c_`-prefixed name.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    generate_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        projection_source = cte.this
        if isinstance(projection_source, exp.SetOperation):
            projection_source = projection_source.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or generate_name())
            for projection in projection_source.selects
        ]
        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replaces the 'epoch' operand of casts to temporal types with the literal
    '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI/ANTI join that has an ON condition is removed from the query and replaced by
    an `EXISTS (SELECT 1 FROM <joined source> WHERE <on>)` predicate (negated for ANTI joins)
    that is added to the WHERE clause. Joins without an ON condition are left as they are.

    Args:
        expression: the expression to transform; only `exp.Select` nodes are rewritten.

    Returns:
        The same expression, modified in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: join.pop() removes the join from the query's join list,
        # and mutating that list while looping over it would skip the join that follows
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites a query that has exactly one FULL OUTER join as a UNION ALL of two queries:
    the original one using a LEFT join, and a copy using a RIGHT join that only keeps the
    rows without a match on the left side.

    Queries with no FULL OUTER join, or with more than one, are returned unchanged.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = [
        (position, candidate)
        for position, candidate in enumerate(expression.args.get("joins") or [])
        if candidate.side == "FULL"
    ]
    if len(full_joins) != 1:
        return expression

    # The copy is taken before the LIMIT is dropped, so only the right side keeps it
    right_query = expression.copy()
    expression.set("limit", None)
    position, full_join = full_joins[0]

    left_name = expression.args["from"].alias_or_name
    right_name = full_join.alias_or_name

    # Without an ON clause, the condition is derived from the USING columns
    condition = full_join.args.get("on") or exp.and_(
        *(
            exp.column(name, left_name).eq(exp.column(name, right_name))
            for name in full_join.args.get("using")
        )
    )

    full_join.set("side", "left")
    unmatched_probe = exp.select("1").from_(expression.args["from"]).where(condition)

    right_query.args["joins"][position].set("side", "right")
    right_query = right_query.where(exp.Exists(this=unmatched_probe).not_())
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side
    expression.args.pop("order", None)  # remove order by from LEFT side

    return exp.union(expression, right_query, copy=False, distinct=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: ~E) -> ~E:
def move_ctes_to_top_level(expression: E) -> E:
    """
    Hoists every nested WITH clause into the WITH clause of `expression`.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    root_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        if nested_with.parent is expression:
            continue

        if not root_with:
            # No top-level WITH yet: the first nested one is promoted as a whole
            root_with = nested_with.pop()
            expression.set("with", root_with)
            continue

        if nested_with.recursive:
            root_with.set("recursive", True)

        # The enclosing CTE has to be looked up before the WITH is detached from the tree
        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if enclosing_cte:
            # CTEs defined inside another CTE are placed right before that CTE
            merged = root_with.expressions
            insert_at = merged.index(enclosing_cte)
            merged[insert_at:insert_at] = nested_with.expressions
        else:
            merged = root_with.expressions + nested_with.expressions

        root_with.set("expressions", merged)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's `ensure_bools` helper together with
    a callback that replaces a node `x` with `x <> 0` when it is a number, has an unknown or
    numeric type (subquery predicates excluded), or is a column without a type.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _to_explicit_bool(node: exp.Expression) -> None:
        numeric_like = node.is_number or (
            not isinstance(node, exp.SubqueryPredicate)
            and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
        )
        if numeric_like or (isinstance(node, exp.Column) and not node.type):
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _to_explicit_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Removes every part but the last one from each column found in the expression."""
    for column in expression.find_all(exp.Column):
        # Everything before the last part is a qualifier (table, db, catalog)
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Removes unique column constraints from a CREATE statement by popping the parent node
    of each `UniqueColumnConstraint` found in it.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        holder = unique.parent
        if holder:
            holder.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Maps a temporary `CREATE TABLE ... AS <query>` to `CREATE TEMPORARY VIEW`.

    Temporary tables that are created without a query are passed to `tmp_storage_provider`
    and its result is returned. Any other CREATE statement is returned unchanged.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []
    is_temporary = any(isinstance(node, exp.TemporaryProperty) for node in property_nodes)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    query = expression.expression
    if not query:
        return tmp_storage_provider(expression)

    return exp.Create(kind="TEMPORARY VIEW", this=expression.this, expression=query)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Moves partition columns out of a table's schema and into its PARTITIONED BY property.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is a list of column names, the matching column definitions are
    removed from the schema of the create statement and become the property's own schema.
    Names are matched case-insensitively.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    partition_names = {name.name.upper() for name in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Moves typed PARTITIONED BY column definitions into the schema of a CREATE statement,
    leaving only the column names in the PARTITIONED BY property.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(d, exp.ColumnDef) and d.kind for d in partition_defs):
        return expression

    # The tuple of names is built before the definitions are appended to the schema
    partition_names = exp.Tuple(expressions=[exp.to_identifier(d.this) for d in partition_defs])

    schema = expression.this
    for definition in partition_defs:
        schema.append("expressions", definition)

    prop.set("this", partition_names)
    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts the key-value (`PropertyEQ`) arguments of a struct into aliases,
    e.g. STRUCT(1 AS y). Other arguments are kept as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            field = exp.alias_(field.expression, field.this)
        fields.append(field)

    expression.set("expressions", fields)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites WHERE predicates that use the (+) join mark as explicit LEFT joins.

    The scopes of the expression are processed from the outermost one inwards. In each scope
    whose WHERE clause contains marked columns, the marked predicates are grouped by the table
    of their marked columns, removed from the WHERE clause and used as the ON condition of a
    new LEFT join on that table. Assertions fail for correlated subqueries, for WHERE clauses
    that don't pass the `normalized` check after normalization, for predicates that mark
    columns of more than one table, and when no table is left to use in the FROM clause.

    https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

    -- example with WHERE
    SELECT d.department_name, sum(e.salary) as total_salary
    FROM departments d, employees e
    WHERE e.department_id(+) = d.department_id
    group by department_name

    -- example of left correlation in select
    SELECT d.department_name, (
        SELECT SUM(e.salary)
            FROM employees e
            WHERE e.department_id(+) = d.department_id) AS total_salary
    FROM departments d;

    -- example of left correlation in from
    SELECT d.department_name, t.total_salary
    FROM departments d, (
            SELECT SUM(e.salary) AS total_salary
            FROM employees e
            WHERE e.department_id(+) = d.department_id
        ) t
    """

    from sqlglot.optimizer.scope import traverse_scope
    from sqlglot.optimizer.normalize import normalize, normalized
    from collections import defaultdict

    # we go in reverse to check the main query for left correlation
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins", [])

        # knockout: we do not support left correlation (see point 2)
        assert not scope.is_correlated_subquery, "Correlated queries are not supported"

        # nothing to do - we check it here after knockout above
        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
            continue

        # make sure we have AND of ORs to have clear join terms
        where = normalize(where.this)
        assert normalized(where), "Cannot normalize JOIN predicates"

        joins_ons = defaultdict(list)  # dict of {name: list of join AND conditions}
        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]

            # Tables of the marked columns in this predicate; unmarked predicates are skipped
            left_join_table = set(col.table for col in join_cols)
            if not left_join_table:
                continue

            assert not (
                len(left_join_table) > 1
            ), "Cannot combine JOIN predicates from different tables"

            # The mark is cleared on the columns once the predicate has been assigned to a join
            for col in join_cols:
                col.set("join_mark", False)

            joins_ons[left_join_table.pop()].append(cond)

        old_joins = {join.alias_or_name: join for join in joins}
        new_joins = {}
        query_from = query.args["from"]

        for table, predicates in joins_ons.items():
            # The joined source is taken from the existing join on that table, or else from FROM
            join_what = old_joins.get(table, query_from).this.copy()
            new_joins[join_what.alias_or_name] = exp.Join(
                this=join_what, on=exp.and_(*predicates), kind="LEFT"
            )

            # Detach each predicate that moved into the ON condition from where it used to be
            for p in predicates:
                # Unwrap any parentheses around the predicate first
                while isinstance(p.parent, exp.Paren):
                    p.parent.replace(p)

                parent = p.parent
                p.pop()
                if isinstance(parent, exp.Binary):
                    # The binary node is replaced by whichever of its operands is left
                    parent.replace(parent.right if parent.left is None else parent.left)
                elif isinstance(parent, exp.Where):
                    # The predicate was the whole WHERE condition, so the clause is removed
                    parent.pop()

        # The FROM source became a LEFT join, so a source that was not turned into a new
        # join takes its place in the FROM clause
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            for n, j in old_joins.items():  # preserve any other joins
                if n not in new_joins and n != query.args["from"].name:
                    # Joins without a kind are given the CROSS kind
                    if not j.kind:
                        j.set("kind", "CROSS")
                    new_joins[n] = j
            query.set("joins", list(new_joins.values()))

    return expression

https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

  1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

  2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

-- example with WHERE
SELECT d.department_name, sum(e.salary) as total_salary
FROM departments d, employees e
WHERE e.department_id(+) = d.department_id
group by department_name

-- example of left correlation in select
SELECT d.department_name, (
    SELECT SUM(e.salary)
    FROM employees e
    WHERE e.department_id(+) = d.department_id) AS total_salary
FROM departments d;

-- example of left correlation in from
SELECT d.department_name, t.total_salary
FROM departments d, (
    SELECT SUM(e.salary) AS total_salary
    FROM employees e
    WHERE e.department_id(+) = d.department_id
) t

def any_to_exists( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation. ANY operands of LIKE / ILIKE are skipped as well.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for any_node in expression.find_all(exp.Any):
        operand = any_node.this
        comparison = any_node.parent

        if isinstance(operand, exp.Query) or isinstance(comparison, (exp.Like, exp.ILike)):
            continue

        if not isinstance(comparison, exp.Binary):
            continue

        # ANY(...) is swapped for the lambda's parameter before the comparison is copied
        # into the lambda body
        element = exp.to_identifier("x")
        any_node.replace(element)
        predicate = exp.Lambda(this=comparison.copy(), expressions=[element])
        comparison.replace(exp.Exists(this=operand.unnest(), expression=predicate))

    return expression

Transform ANY operator to Spark's EXISTS

For example:

  - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
  - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation

def eliminate_window_clause( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 974def eliminate_window_clause(expression: exp.Expression) -> exp.Expression:
 975    """Eliminates the `WINDOW` query clause by inling each named window."""
 976    if isinstance(expression, exp.Select) and expression.args.get("windows"):
 977        from sqlglot.optimizer.scope import find_all_in_scope
 978
 979        windows = expression.args["windows"]
 980        expression.set("windows", None)
 981
 982        window_expression: t.Dict[str, exp.Expression] = {}
 983
 984        def _inline_inherited_window(window: exp.Expression) -> None:
 985            inherited_window = window_expression.get(window.alias.lower())
 986            if not inherited_window:
 987                return
 988
 989            window.set("alias", None)
 990            for key in ("partition_by", "order", "spec"):
 991                arg = inherited_window.args.get(key)
 992                if arg:
 993                    window.set(key, arg.copy())
 994
 995        for window in windows:
 996            _inline_inherited_window(window)
 997            window_expression[window.name.lower()] = window
 998
 999        for window in find_all_in_scope(expression, exp.Window):
1000            _inline_inherited_window(window)
1001
1002    return expression

Eliminates the WINDOW query clause by inlining each named window.