# sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.errors import UnsupportedError
from sqlglot.helper import find_new_name, name_sequence


if t.TYPE_CHECKING:
    from sqlglot._typing import E
    from sqlglot.generator import Generator


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable).

    The SQL for the transformed expression is produced as follows:
        1. If the generator has a `<key>_sql` method for the transformed expression, use it.
        2. Otherwise, if `Generator.TRANSFORMS` has an entry for the transformed expression's
           type, use it -- unless the type is the same as the original expression's type, in
           which case `Func` nodes fall back to `function_fallback_sql` and anything else raises
           a `ValueError` (re-entering the same handler would recurse forever).
        3. Otherwise, raise a `ValueError`.

    If a transform raises `UnsupportedError`, the message is reported through
    `Generator.unsupported` and generation continues with the expression as it stands.

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence leaves the expression unchanged.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            # Iterating the whole sequence (instead of indexing transforms[0]) also makes an
            # empty transform list a harmless no-op rather than an IndexError.
            for transform in transforms:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql


def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite `UNNEST(GENERATE_DATE_ARRAY(start, end, step))` sources into recursive CTEs.

    Only UNNESTs that sit directly in a FROM/JOIN, wrap a single GENERATE_DATE_ARRAY and have a
    start, an end and an INTERVAL step are rewritten. Each one is replaced by a subquery over a new
    recursive CTE (`_generated_dates`, `_generated_dates_1`, ...) that starts at `start` and keeps
    adding the step while the result is `<= end`. The new CTEs are prepended to the query's WITH
    clause, which is marked RECURSIVE.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            # A table alias without a column list (e.g. `AS d`) has no column name to reuse, so
            # fall back to the default instead of failing with an IndexError.
            column_name = (
                alias.columns[0]
                if isinstance(alias, exp.TableAlias) and alias.columns
                else "date_value"
            )

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression


def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        order = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number_window_alias)
        expression.select(window, copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects = []
        taken_names = {row_number_window_alias}
        for select in expression.selects[:-1]:
            if select.is_star:
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            # Copy so the projection keeps its own node: moving it would
                            # re-parent it, and repeated references would share one node.
                            column.replace(expr.copy())

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                leftmost_part = column.parts[0]
                if leftmost_part.arg_key != "this" and leftmost_part.this in unnest_aliases:
                    leftmost_part.pop()

    return expression


def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            columns = alias.columns if alias else []
            offset = unnest.args.get("offset")
            if offset:
                columns.insert(
                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
                )

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                offset = unnest.args.get("offset")
                if offset:
                    alias_cols.insert(
                        0,
                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
                    )

                # zip() only bounds the number of laterals; the paired column itself is unused
                # because every lateral gets the full column list.
                for e, _ in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression


def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode projections into unnests."""

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: join.pop() removes the join from the query's own joins list,
        # which would otherwise make the loop skip the join that follows it.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression


def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the nested CTEs right before the CTE that used them, so they are
                # defined before they are referenced.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strip every column reference down to its last part (the column name itself)."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Remove, from a CREATE statement, the node wrapping each `UniqueColumnConstraint`."""
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite `CREATE TEMPORARY TABLE ... AS <query>` into `CREATE TEMPORARY VIEW ... AS <query>`.

    A temporary CREATE TABLE without a query is passed to `tmp_storage_provider` instead (the
    identity function by default). Any other CREATE statement is returned unchanged.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Column names are matched case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression


def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression


def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

    -- example with WHERE
    SELECT d.department_name, sum(e.salary) as total_salary
    FROM departments d, employees e
    WHERE e.department_id(+) = d.department_id
    group by department_name

    -- example of left correlation in select
    SELECT d.department_name, (
        SELECT SUM(e.salary)
            FROM employees e
            WHERE e.department_id(+) = d.department_id) AS total_salary
    FROM departments d;

    -- example of left correlation in from
    SELECT d.department_name, t.total_salary
    FROM departments d, (
            SELECT SUM(e.salary) AS total_salary
            FROM employees e
            WHERE e.department_id(+) = d.department_id
        ) t
    """

    from sqlglot.optimizer.scope import traverse_scope
    from sqlglot.optimizer.normalize import normalize, normalized
    from collections import defaultdict

    # we go in reverse to check the main query for left correlation
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where = query.args.get("where")
        # `or []` also covers a "joins" arg that is present but set to None.
        joins = query.args.get("joins") or []

        # knockout: we do not support left correlation (see point 2)
        assert not scope.is_correlated_subquery, "Correlated queries are not supported"

        # nothing to do - we check it here after knockout above
        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
            continue

        # make sure we have AND of ORs to have clear join terms
        where = normalize(where.this)
        assert normalized(where), "Cannot normalize JOIN predicates"

        joins_ons = defaultdict(list)  # dict of {name: list of join AND conditions}
        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]

            left_join_table = set(col.table for col in join_cols)
            if not left_join_table:
                continue

            assert not (
                len(left_join_table) > 1
            ), "Cannot combine JOIN predicates from different tables"

            for col in join_cols:
                col.set("join_mark", False)

            joins_ons[left_join_table.pop()].append(cond)

        old_joins = {join.alias_or_name: join for join in joins}
        new_joins = {}
        query_from = query.args["from"]

        for table, predicates in joins_ons.items():
            join_what = old_joins.get(table, query_from).this.copy()
            new_joins[join_what.alias_or_name] = exp.Join(
                this=join_what, on=exp.and_(*predicates), kind="LEFT"
            )

            # Remove the predicates (now living in the ON clause) from the WHERE clause,
            # unwrapping parentheses and collapsing the binary nodes they leave behind.
            for p in predicates:
                while isinstance(p.parent, exp.Paren):
                    p.parent.replace(p)

                parent = p.parent
                p.pop()
                if isinstance(parent, exp.Binary):
                    parent.replace(parent.right if parent.left is None else parent.left)
                elif isinstance(parent, exp.Where):
                    parent.pop()

        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            # old_joins is keyed by alias_or_name, so compare against the FROM's alias_or_name
            # too; comparing against the bare table name would re-add an aliased table that
            # was just promoted to FROM as an extra CROSS JOIN.
            from_name = query.args["from"].alias_or_name
            for n, j in old_joins.items():  # preserve any other joins
                if n not in new_joins and n != from_name:
                    if not j.kind:
                        j.set("kind", "CROSS")
                    new_joins[n] = j
            query.set("joins", list(new_joins.values()))

    return expression


def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if isinstance(expression, exp.Select):
        for any_expr in expression.find_all(exp.Any):
            this = any_expr.this
            if isinstance(this, exp.Query) or isinstance(any_expr.parent, (exp.Like, exp.ILike)):
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                lambda_arg = exp.to_identifier("x")
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression


def eliminate_window_clause(expression: exp.Expression) -> exp.Expression:
    """Eliminates the `WINDOW` query clause by inlining each named window."""
    if isinstance(expression, exp.Select) and expression.args.get("windows"):
        from sqlglot.optimizer.scope import find_all_in_scope

        windows = expression.args["windows"]
        expression.set("windows", None)

        # Named windows seen so far, keyed by lower-cased name (lookups are case-insensitive).
        window_expression: t.Dict[str, exp.Expression] = {}

        def _inline_inherited_window(window: exp.Expression) -> None:
            inherited_window = window_expression.get(window.alias.lower())
            if not inherited_window:
                return

            window.set("alias", None)
            for key in ("partition_by", "order", "spec"):
                arg = inherited_window.args.get(key)
                if arg:
                    window.set(key, arg.copy())

        # Named windows may build on previously defined ones, so resolve them in order.
        for window in windows:
            _inline_inherited_window(window)
            window_expression[window.name.lower()] = window

        for window in find_all_in_scope(expression, exp.Window):
            _inline_inherited_window(window)

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is accepted, in which case the expression is generated unchanged.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            # Iterate over the whole sequence instead of indexing transforms[0], so that an
            # empty sequence does not raise IndexError.
            for transform in transforms:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            # Report through the generator and keep going with the result of the last
            # transform that succeeded.
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """
    Replace `UNNEST(GENERATE_DATE_ARRAY(start, end, step))` table sources with recursive CTEs.

    An UNNEST qualifies when it sits directly in a FROM or JOIN clause, wraps exactly one
    GENERATE_DATE_ARRAY call, has both `start` and `end`, and uses an INTERVAL step. Each one is
    replaced by a subquery over a new recursive CTE (`_generated_dates`, `_generated_dates_1`, ...).
    The CTE's anchor selects CAST(start AS DATE). Its recursive member keeps adding `step` with
    DATE_ADD while the new date is <= CAST(end AS DATE). The new CTEs are prepended to the query's
    WITH clause, which is marked RECURSIVE.

    Args:
        expression: the expression to transform; only `exp.Select` nodes are modified.

    Returns:
        The (possibly modified) expression.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        # Snapshot the matches: qualifying UNNEST nodes are replaced inside the loop, and
        # mutating the tree while a lazy traversal is still running is fragile.
        for unnest in list(expression.find_all(exp.Unnest)):
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            # A table alias may have no column list (e.g. `UNNEST(...) AS t`); indexing
            # alias.columns[0] would then raise IndexError, so fall back to the default name.
            column_name = (
                alias.columns[0]
                if isinstance(alias, exp.TableAlias) and alias.columns
                else "date_value"
            )

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            # NOTE(review): the replacement subquery is aliased with the CTE name, so any table
            # alias name on the original UNNEST is not carried over.
            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references.

    A table whose `this` is a GENERATE_SERIES node is replaced by `UNNEST(<series>)`. If the
    table had an alias, the UNNEST is aliased as `_u(<table alias>)`, so the table alias
    becomes the produced column's name. Any other expression is returned untouched.
    """
    series = expression.this
    if not isinstance(expression, exp.Table) or not isinstance(series, exp.GenerateSeries):
        return expression

    unnested = exp.Unnest(expressions=[series])
    table_alias = expression.alias
    if not table_alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)
Unnests GENERATE_SERIES or SEQUENCE table references.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The query is rewritten as
    `SELECT <projections> FROM (<query with ROW_NUMBER() OVER (...)>) AS _t WHERE <rn> = 1`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    # Only DISTINCT ON (<tuple>) is rewritten; a plain DISTINCT has no "on" tuple.
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        # Alias for the ROW_NUMBER() projection that doesn't clash with existing projection names.
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        # Detach the DISTINCT clause; its ON keys become the window's PARTITION BY.
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        # The query's ORDER BY is moved into the window (the outer query built below has no
        # ORDER BY); without one, rows are ordered by copies of the DISTINCT ON keys.
        order = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number_window_alias)
        expression.select(window, copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects = []
        taken_names = {row_number_window_alias}
        # selects[:-1] skips the ROW_NUMBER() projection appended just above.
        for select in expression.selects[:-1]:
            if select.is_star:
                # A star can't be enumerated: the outer query selects * from the subquery
                # (which therefore also exposes the row-number column).
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                # Preserve the quoting of a bare column's identifier in the new alias.
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        # Keep only the first row of each DISTINCT ON partition.
        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh `_c...` alias so the outer query can name it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Outer-query projection for an inner projection: a column reference that keeps the
        # identifier's quoting when the name comes from an Identifier, else the plain name string.
        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Built before any extra projections are added, so the outer query only exposes the
        # original ones.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach QUALIFY; its condition becomes the outer query's WHERE.
        qualify_filters = expression.args["qualify"].pop().this
        # alias -> aliased expression, used to inline aliases referenced inside windows.
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With SELECT *, every column is already projected, so only windows are pushed down.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # NOTE(review): `expr` is inserted without .copy(), so the node is also
                    # still referenced by its projection.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the inner query under a fresh `_w...` alias and filter
                # on that alias instead.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # If the window is the whole QUALIFY condition, the filter is just the alias.
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Columns used by QUALIFY but not projected must be projected by the inner
                # query so the outer WHERE can reference them.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip type parameters (e.g. precision/length) from every data type in the expression.

    Some dialects accept parameterized types such as DECIMAL(10, 2) only in DDL. This transform
    drops every `exp.DataTypeParam` argument of each `exp.DataType` node, keeping any other
    arguments, and mutates the tree in place.
    """

    def _is_kept(arg: exp.Expression) -> bool:
        return not isinstance(arg, exp.DataTypeParam)

    for data_type in expression.find_all(exp.DataType):
        data_type.set("expressions", list(filter(_is_kept, data_type.expressions)))

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Collects the aliases of UNNEST sources in the SELECT's own scope that sit directly in a FROM
    or JOIN clause. It then drops the leftmost qualifier of every column whose qualifier (any part
    other than the column name itself) matches one of those aliases.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    aliases = set()
    for source in find_all_in_scope(expression, exp.Unnest):
        if isinstance(source.parent, (exp.From, exp.Join)):
            aliases.add(source.alias)

    if not aliases:
        return expression

    for col in expression.find_all(exp.Column):
        first_part = col.parts[0]
        # arg_key == "this" means the first part is the column name itself, not a qualifier.
        if first_part.arg_key != "this" and first_part.this in aliases:
            first_part.pop()

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode.

    - An UNNEST in the FROM clause becomes a table over EXPLODE / INLINE / POSEXPLODE.
    - An UNNEST (optionally wrapped in LATERAL) in a join is removed from the joins and re-added
      as a `LATERAL VIEW` entry in the query's "laterals".

    Args:
        expression: the expression to transform; only `exp.Select` nodes are modified.
        unnest_using_arrays_zip: when True, an UNNEST over several arrays is turned into a single
            ARRAYS_ZIP(...) argument; when False, such an UNNEST raises UnsupportedError.

    Returns:
        The (possibly modified) expression.
    """

    # Collapses a multi-array UNNEST into one ARRAYS_ZIP(...) argument (mutating `u`).
    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    # Picks the table-generating function: POSEXPLODE when an offset (WITH OFFSET) is present,
    # INLINE for zipped multi-array input, EXPLODE otherwise.
    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            # The position column (the offset identifier, or "pos") goes first in the alias.
            # Note that `columns` is the alias's own list, so the insert mutates it in place.
            columns = alias.columns if alias else []
            offset = unnest.args.get("offset")
            if offset:
                columns.insert(
                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
                )

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        # Iterate over a copy because matching joins are removed from `joins` below.
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For LATERAL UNNEST(...) the alias lives on the Lateral node.
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                offset = unnest.args.get("offset")
                if offset:
                    alias_cols.insert(
                        0,
                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
                    )

                # `exprs` has exactly one element at this point, so at most one lateral view is
                # appended; it carries all alias columns.
                # NOTE(review): for a multi-array UNNEST without an alias (and no offset),
                # alias_cols is empty, so nothing is appended and the join is silently dropped.
                # With an offset, the loop runs and `alias.this` is read on a None alias.
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode projections into unnests.

    Returns a transform for SELECT statements. The transform:

    - Cross joins a single position series: `UNNEST(GENERATE_SERIES(index_offset, <end>))`.
      `<end>` is the greatest array size, adjusted for `index_offset`.
    - Turns each EXPLODE / POSEXPLODE / EXPLODE_OUTER projection into a CROSS JOIN UNNEST of
      its argument WITH OFFSET.
    - Replaces the projection with `IF(<series pos> = <unnest pos>, <element>)`.
    - Adds a WHERE condition that keeps the array and series positions aligned.

    Args:
        index_offset: the value the position series starts at (presumably the target
            dialect's array index base -- confirm against callers).

    Returns:
        The transform function.
    """

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            # Returns a name not yet in `names` and reserves it.
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded argument; used to compute the series end.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # The shared position series; its "end" is filled in after the loop.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Reuse user-provided aliases: `AS x` names the element, `AS (pos, x)`
                    # names both the position and the element. Otherwise wrap the projection
                    # in an alias that is named below.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # EXPLODE_OUTER: when the array is NULL or empty, use a one-element array
                    # holding a "safe" element access instead, presumably so that a row
                    # (with a NULL element) is still produced.
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The projection yields the element only on the row where the series
                    # position matches this unnest's position (NULL otherwise).
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also projects the position, right after the element column.
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # Attach the shared series once: as a CROSS JOIN, or as FROM if none exists.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # The last valid position is size - 1 unless positions start at 1.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match or, once the series runs past this
                    # array's last position, rows paired with the array's last position.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            # The series must span the longest exploded array.
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest
Convert explode/posexplode projections into unnests.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them.

    A percentile function that has both `this` (the ordered column) and `expression` (the
    quantile), and is not already inside a WITHIN GROUP, is rewritten as follows. The quantile
    becomes the function's only argument, and the column moves into
    `WITHIN GROUP (ORDER BY <column>)`. The new WithinGroup node is returned.
    """
    needs_within_group = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not needs_within_group:
        return expression

    ordered_column = expression.this.pop()
    expression.set("this", expression.expression.pop())
    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_column)]),
    )
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    `PERCENTILE_X(q) WITHIN GROUP (ORDER BY col)` is replaced in the tree by
    `APPROX_QUANTILE(col, q)`. The ordered column is taken from the first `Ordered` node found.
    Any other expression is returned untouched.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile_value = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=quantile_value)
    return expression.replace(replacement)
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only applies to WITH RECURSIVE clauses. Each CTE without an explicit column list gets one
    built from its query's projections. For a set operation, the left-hand query's projections
    are used. Unnamed projections get generated `_c_<n>` names.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fallback_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        body = cte.this
        if isinstance(body, exp.SetOperation):
            body = body.this

        column_names = [exp.to_identifier(s.alias_or_name or fallback_name()) for s in body.selects]
        cte_alias.set("columns", column_names)

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts to temporal types with the literal '1970-01-01 00:00:00'.

    Applies to CAST and TRY_CAST nodes whose operand's name is 'epoch' (case-insensitive) and
    whose target type is one of `exp.DataType.TEMPORAL_TYPES`. The operand is replaced in place.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    if expression.name.lower() == "epoch" and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression
Replace 'epoch' in casts to temporal types by the equivalent timestamp literal ('1970-01-01 00:00:00').
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    `... SEMI JOIN b ON <cond>` becomes `WHERE EXISTS(SELECT 1 FROM b WHERE <cond>)`.
    `... ANTI JOIN b ON <cond>` becomes `WHERE NOT EXISTS(...)`.
    The predicate is added to any existing WHERE clause. SEMI/ANTI joins without an ON
    condition are left untouched.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` detaches the join from the query's "joins"
        # list, and if that list is mutated in place while being iterated, the join right
        # after each removed one would be skipped (e.g. two consecutive SEMI joins).
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The left branch turns the join into a LEFT join. The right branch is a copy that uses a
    RIGHT join and adds `WHERE NOT EXISTS (SELECT 1 FROM <from> WHERE <join condition>)`, so
    that rows already produced by the left branch are not repeated. The branches are combined
    with UNION ALL. LIMIT and ORDER BY are kept only on the right branch and WITH only on the
    left. Queries whose FULL join has neither an ON nor a USING condition (e.g. NATURAL FULL
    JOIN) are returned unchanged, since no join condition can be built for them.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            index, full_outer_join = full_outer_joins[0]
            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)

            # Build the join condition before mutating anything, so an unsupported join shape
            # leaves the query untouched.
            join_conditions = full_outer_join.args.get("on")
            if not join_conditions:
                using = full_outer_join.args.get("using") or []
                if not using:
                    # Previously this iterated over None and raised TypeError.
                    return expression
                join_conditions = exp.and_(
                    *[exp.column(col, tables[0]).eq(exp.column(col, tables[1])) for col in using]
                )

            expression_copy = expression.copy()
            expression.set("limit", None)

            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    # NOTE(review): the tree is mutated (WITH nodes popped/moved) while this traversal runs.
    for inner_with in expression.find_all(exp.With):
        # Already the top-level WITH.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: hoist the first nested one as-is.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # Any recursive nested WITH makes the merged top-level WITH recursive.
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Nested inside a CTE: insert its CTEs right before that CTE so they are
                # defined before use. (list.index raises ValueError if parent_cte is not in
                # the top-level list -- presumably it always is by the time it's reached.)
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                # Re-set the list, presumably so the spliced-in CTEs get their parent links.
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's canonicalize helper with a callback. The
    callback rewrites a node to `<node> <> 0` when the node is:

    - a number literal;
    - a non-subquery-predicate whose type is UNKNOWN or numeric; or
    - a column with no type information.
    """
    # Imported under a different name so it doesn't shadow this function inside its own body.
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _compare_to_zero(node: exp.Expression) -> None:
        looks_numeric = node.is_number or (
            not isinstance(node, exp.SubqueryPredicate)
            and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
        )
        if looks_numeric or (isinstance(node, exp.Column) and not node.type):
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _compare_to_zero)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite CREATE statements for temporary tables.

    A `CREATE TABLE` carrying a TemporaryProperty is handled as follows:

    - If it has a query (CTAS), it becomes a new `CREATE TEMPORARY VIEW <name> AS <query>`.
      Only the name and the query are carried over.
    - Otherwise it is passed to `tmp_storage_provider`, which is the identity by default.

    Any other CREATE is returned untouched.

    Args:
        expression: a `exp.Create` node (asserted).
        tmp_storage_provider: hook applied to temporary tables that have no query.

    Returns:
        The (possibly replaced) expression.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    query = expression.expression
    if not query:
        return tmp_storage_provider(expression)

    return exp.Create(kind="TEMPORARY VIEW", this=expression.this, expression=query)
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Column names are matched case-insensitively. Only CREATE TABLE / CREATE VIEW statements whose
    target is a Schema are affected.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}
    if not (has_schema and is_partitionable):
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    partition_names = {name_expr.name.upper() for name_expr in prop.this.expressions}
    partition_columns = [col for col in schema.expressions if col.name.upper() in partition_names]
    remaining_columns = [col for col in schema.expressions if col not in partition_columns]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)
    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When PARTITIONED BY holds a schema whose entries are all typed column definitions, those
    definitions are appended to the table's schema. PARTITIONED BY is then reduced to a tuple
    of just the column names.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    column_defs = prop.this.expressions
    if not all(isinstance(col_def, exp.ColumnDef) and col_def.kind for col_def in column_defs):
        return expression

    name_tuple = exp.Tuple(expressions=[exp.to_identifier(col_def.this) for col_def in column_defs])
    table_schema = expression.this
    for col_def in column_defs:
        table_schema.append("expressions", col_def)
    prop.set("this", name_tuple)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each `key := value` (PropertyEQ) argument of a STRUCT becomes `value AS key`. Other arguments
    are kept as they are. Non-struct expressions are returned untouched.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    def _as_alias(arg: exp.Expression) -> exp.Expression:
        if isinstance(arg, exp.PropertyEQ):
            return exp.alias_(arg.expression, arg.this)
        return arg

    expression.set("expressions", list(map(_as_alias, expression.expressions)))
    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """Rewrite Oracle-style `(+)` join marks in WHERE clauses into explicit LEFT JOINs.

    https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

    -- example with WHERE
    SELECT d.department_name, sum(e.salary) as total_salary
    FROM departments d, employees e
    WHERE e.department_id(+) = d.department_id
    group by department_name

    -- example of left correlation in select (not supported: correlated subqueries are rejected)
    SELECT d.department_name, (
        SELECT SUM(e.salary)
            FROM employees e
            WHERE e.department_id(+) = d.department_id) AS total_salary
    FROM departments d;

    -- example of left correlation in from (not supported: correlated subqueries are rejected)
    SELECT d.department_name, t.total_salary
    FROM departments d, (
            SELECT SUM(e.salary) AS total_salary
            FROM employees e
            WHERE e.department_id(+) = d.department_id
        ) t

    Args:
        expression: the expression tree to rewrite; its scopes are modified in place.

    Returns:
        The same (mutated) expression.

    Raises:
        AssertionError: if a scope is a correlated subquery, the WHERE predicates cannot be
            normalized, a single predicate marks columns of more than one table, or no table
            remains to be used in the rewritten FROM clause.
    """

    from sqlglot.optimizer.scope import traverse_scope
    from sqlglot.optimizer.normalize import normalize, normalized
    from collections import defaultdict

    # we go in reverse to check the main query for left correlation
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where = query.args.get("where")
        # "joins" may be present but set to None; treat that like no joins at all
        joins = query.args.get("joins") or []

        # knockout: we do not support left correlation (see point 2)
        assert not scope.is_correlated_subquery, "Correlated queries are not supported"

        # nothing to do - we check it here after knockout above
        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
            continue

        # make sure we have AND of ORs to have clear join terms
        where = normalize(where.this)
        assert normalized(where), "Cannot normalize JOIN predicates"

        joins_ons = defaultdict(list)  # dict of {name: list of join AND conditions}
        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]

            left_join_table = set(col.table for col in join_cols)
            if not left_join_table:
                continue

            assert not (
                len(left_join_table) > 1
            ), "Cannot combine JOIN predicates from different tables"

            # the marks are consumed here; the condition becomes a plain ON predicate
            for col in join_cols:
                col.set("join_mark", False)

            joins_ons[left_join_table.pop()].append(cond)

        # keyed by alias when present, otherwise by table name
        old_joins = {join.alias_or_name: join for join in joins}
        new_joins = {}
        query_from = query.args["from"]

        for table, predicates in joins_ons.items():
            # a marked table that is not joined must be the FROM table
            join_what = old_joins.get(table, query_from).this.copy()
            new_joins[join_what.alias_or_name] = exp.Join(
                this=join_what, on=exp.and_(*predicates), kind="LEFT"
            )

            # detach each moved predicate from the WHERE clause, collapsing the connector
            # (or the whole WHERE) it leaves behind
            for p in predicates:
                while isinstance(p.parent, exp.Paren):
                    p.parent.replace(p)

                parent = p.parent
                p.pop()
                if isinstance(parent, exp.Binary):
                    parent.replace(parent.right if parent.left is None else parent.left)
                elif isinstance(parent, exp.Where):
                    parent.pop()

        # the FROM table itself became the outer-joined side: promote an unmarked join to FROM
        if query_from.this.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            # Join keys are aliases (or names when unaliased), so the current FROM table must be
            # compared by alias as well; comparing by bare table name would re-add an aliased
            # table that was just promoted to FROM as a spurious CROSS JOIN.
            from_name = query.args["from"].this.alias_or_name
            for n, j in old_joins.items():  # preserve any other joins
                if n not in new_joins and n != from_name:
                    if not j.kind:
                        j.set("kind", "CROSS")
                    new_joins[n] = j
            query.set("joins", list(new_joins.values()))

    return expression
https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178
You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.
The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.
The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.
You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.
The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.
A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.
A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.
A WHERE condition cannot compare any column marked with the (+) operator with a subquery.
-- example with WHERE SELECT d.department_name, sum(e.salary) as total_salary FROM departments d, employees e WHERE e.department_id(+) = d.department_id group by department_name
-- example of left correlation in select SELECT d.department_name, ( SELECT SUM(e.salary) FROM employees e WHERE e.department_id(+) = d.department_id) AS total_salary FROM departments d;
-- example of left correlation in from SELECT d.department_name, t.total_salary FROM departments d, ( SELECT SUM(e.salary) AS total_salary FROM employees e WHERE e.department_id(+) = d.department_id ) t
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> 5 > x)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation.

    The lambda parameter is named ``x`` unless that name (compared case-insensitively) already
    occurs as an identifier in the comparison outside the ANY operand. In that case a fresh name
    (``x_2``, ``x_3``, ...) is used, so the lambda parameter cannot shadow a column reference,
    e.g. ``x > ANY(arr)`` becomes ``EXISTS(arr, x_2 -> x > x_2)``.

    Args:
        expression: the expression to transform; only SELECT expressions are rewritten, in place.

    Returns:
        The same expression.
    """
    if isinstance(expression, exp.Select):
        for any_expr in expression.find_all(exp.Any):
            this = any_expr.this
            # subquery ANY and LIKE/ILIKE ANY are left as they are
            if isinstance(this, exp.Query) or isinstance(any_expr.parent, (exp.Like, exp.ILike)):
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                # Names already used by the comparison (the ANY operand is excluded, since it
                # becomes EXISTS' first argument rather than part of the lambda body).
                taken = {
                    node.name.lower()
                    for node in binop.walk(prune=lambda n: n is any_expr)
                    if isinstance(node, exp.Identifier)
                }
                lambda_arg = exp.to_identifier(find_new_name(taken, "x"))
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression
Transform the ANY operator into Spark's EXISTS.
For example, Postgres' `SELECT * FROM tbl WHERE 5 > ANY(tbl.col)` becomes Spark's `SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> 5 > x)`.
Both ANY and EXISTS accept queries, but currently only array expressions are supported for this transformation.
def eliminate_window_clause(expression: exp.Expression) -> exp.Expression:
    """Eliminates the `WINDOW` query clause by inlining each named window.

    The SELECT's named window definitions are removed. Every window in the query's scope that
    references one of them by name gets a copy of that definition's `partition_by`, `order`
    and `spec` args, and its name reference is cleared. A definition may itself reference an
    earlier definition in the same clause.
    """
    if not isinstance(expression, exp.Select) or not expression.args.get("windows"):
        return expression

    from sqlglot.optimizer.scope import find_all_in_scope

    named_windows = expression.args["windows"]
    expression.set("windows", None)

    # lower-cased window name -> its (already inlined) definition
    definitions: t.Dict[str, exp.Expression] = {}

    def _apply_definition(target: exp.Expression) -> None:
        source = definitions.get(target.alias.lower())
        if not source:
            return

        target.set("alias", None)
        for arg_name in ("partition_by", "order", "spec"):
            value = source.args.get(arg_name)
            if value:
                target.set(arg_name, value.copy())

    # Register definitions in order so later ones can build on earlier ones.
    for definition in named_windows:
        _apply_definition(definition)
        definitions[definition.name.lower()] = definition

    for usage in find_all_in_scope(expression, exp.Window):
        _apply_definition(usage)

    return expression
Eliminates the `WINDOW` query clause by inlining each named window.