sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.errors import UnsupportedError
from sqlglot.helper import find_new_name, name_sequence


if t.TYPE_CHECKING:
    from sqlglot._typing import E
    from sqlglot.generator import Generator
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
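A minimal sketch of how a dialect generator might wire this up. The MyGenerator subclass is hypothetical, but preprocess, eliminate_qualify and eliminate_distinct_on are the transforms documented on this page:

from sqlglot import exp, generator, transforms

class MyGenerator(generator.Generator):  # hypothetical generator subclass
    TRANSFORMS = {
        **generator.Generator.TRANSFORMS,
        # rewrite QUALIFY and DISTINCT ON before SELECT statements are generated
        exp.Select: transforms.preprocess(
            [transforms.eliminate_qualify, transforms.eliminate_distinct_on]
        ),
    }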
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression
Unnests GENERATE_SERIES or SEQUENCE table references.
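A rough usage sketch; the query string is illustrative and assumes the default dialect parses GENERATE_SERIES as a table function:

import sqlglot
from sqlglot.transforms import unnest_generate_series

# the table-valued GENERATE_SERIES call is rewritten into an UNNEST over the
# same call; the exact SQL text depends on the target dialect's generator
ast = sqlglot.parse_one("SELECT s FROM GENERATE_SERIES(1, 5) AS s")
print(ast.transform(unnest_generate_series).sql())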
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        order = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number_window_alias)
        expression.select(window, copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects = []
        taken_names = {row_number_window_alias}
        for select in expression.selects[:-1]:
            if select.is_star:
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
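A hedged example of applying the transform directly (dialects normally run it through preprocess); the query is illustrative:

import sqlglot
from sqlglot.transforms import eliminate_distinct_on

sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, c DESC"
ast = sqlglot.parse_one(sql, read="postgres")

# DISTINCT ON (a) becomes a ROW_NUMBER() window partitioned by a inside a
# subquery, and the outer query keeps only rows where that row number is 1
print(ast.transform(eliminate_distinct_on).sql())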
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(select_candidates)):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
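A small sketch, using an illustrative Snowflake-style query:

import sqlglot
from sqlglot.transforms import eliminate_qualify

sql = (
    "SELECT id, amount FROM orders "
    "QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY amount DESC) = 1"
)
ast = sqlglot.parse_one(sql, read="snowflake")

# the window function is pushed into a subquery as an aliased projection and
# the outer query filters on that alias
print(ast.transform(eliminate_qualify).sql())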
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transforms removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
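For example (illustrative query, applied directly rather than through a generator):

import sqlglot
from sqlglot.transforms import remove_precision_parameterized_types

ast = sqlglot.parse_one("SELECT CAST(x AS DECIMAL(10, 2)) FROM t")

# the DataType parameters (10, 2) are dropped, leaving a bare DECIMAL
print(ast.transform(remove_precision_parameterized_types).sql())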
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
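A rough way to see the effect is to transpile from a dialect that uses CROSS JOIN UNNEST to one whose generator applies this transform; the query below is illustrative:

import sqlglot

# roughly: CROSS JOIN UNNEST(t.arr) AS x(a) becomes LATERAL VIEW EXPLODE(t.arr) x AS a
print(
    sqlglot.transpile(
        "SELECT x.a FROM t CROSS JOIN UNNEST(t.arr) AS x(a)",
        read="presto",
        write="spark",
    )[0]
)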
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode projections into unnests."""

    def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest
Convert explode/posexplode projections into unnests.
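Note that this is a factory rather than a transform itself: calling it returns the actual transform, with index_offset controlling whether the emulated positions are 0- or 1-based. A hedged sketch with an illustrative query:

import sqlglot
from sqlglot.transforms import explode_projection_to_unnest

# 1-based offset, as a Presto/Trino-style engine would need
transform = explode_projection_to_unnest(index_offset=1)

ast = sqlglot.parse_one("SELECT EXPLODE(arr) FROM t", read="spark")

# the EXPLODE projection is replaced by CROSS JOIN UNNEST plus a generated
# position series that emulates the row expansion
print(ast.transform(transform).sql())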
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
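A tentative sketch; the query and dialect choice are illustrative, and the exact output text depends on the generator used afterwards:

import sqlglot
from sqlglot.transforms import remove_within_group_for_percentiles

sql = "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t"
ast = sqlglot.parse_one(sql, read="postgres")

# the WITHIN GROUP wrapper is replaced with an approximate quantile call over x
print(ast.transform(remove_within_group_for_percentiles).sql())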
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
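For example (illustrative query):

import sqlglot
from sqlglot.transforms import add_recursive_cte_column_names

sql = "WITH RECURSIVE t AS (SELECT 1 AS n UNION ALL SELECT n + 1 FROM t WHERE n < 3) SELECT n FROM t"
ast = sqlglot.parse_one(sql)

# the CTE alias picks up the anchor query's projection name, i.e. t(n)
print(ast.transform(add_recursive_cte_column_names).sql())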
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts by the equivalent date literal.
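For example:

import sqlglot
from sqlglot.transforms import epoch_cast_to_ts

ast = sqlglot.parse_one("SELECT CAST('epoch' AS TIMESTAMP) FROM t")

# 'epoch' is replaced with the literal '1970-01-01 00:00:00'
print(ast.transform(epoch_cast_to_ts).sql())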
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
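A hedged sketch, assuming a Spark-style LEFT SEMI JOIN as input (the query is illustrative):

import sqlglot
from sqlglot.transforms import eliminate_semi_and_anti_joins

ast = sqlglot.parse_one("SELECT a.id FROM a LEFT SEMI JOIN b ON a.id = b.id", read="spark")

# the SEMI join is dropped and re-expressed as
# WHERE EXISTS(SELECT 1 FROM b WHERE a.id = b.id)
print(ast.transform(eliminate_semi_and_anti_joins).sql())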
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
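A sketch, applied directly to a parsed SELECT (table names are hypothetical; the exact generated SQL may differ):

    import sqlglot
    from sqlglot.transforms import eliminate_full_outer_join

    sql = "SELECT * FROM a FULL OUTER JOIN b ON a.id = b.id"
    # The result is a UNION ALL of a LEFT-joined query and a RIGHT-joined query, where the
    # RIGHT-joined half is guarded by NOT EXISTS so that matched rows are not duplicated.
    print(eliminate_full_outer_join(sqlglot.parse_one(sql)).sql())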
682def move_ctes_to_top_level(expression: E) -> E: 683 """ 684 Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be 685 defined at the top-level, so for example queries like: 686 687 SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq 688 689 are invalid in those dialects. This transformation can be used to ensure all CTEs are 690 moved to the top level so that the final SQL code is valid from a syntax standpoint. 691 692 TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). 693 """ 694 top_level_with = expression.args.get("with") 695 for inner_with in expression.find_all(exp.With): 696 if inner_with.parent is expression: 697 continue 698 699 if not top_level_with: 700 top_level_with = inner_with.pop() 701 expression.set("with", top_level_with) 702 else: 703 if inner_with.recursive: 704 top_level_with.set("recursive", True) 705 706 parent_cte = inner_with.find_ancestor(exp.CTE) 707 inner_with.pop() 708 709 if parent_cte: 710 i = top_level_with.expressions.index(parent_cte) 711 top_level_with.expressions[i:i] = inner_with.expressions 712 top_level_with.set("expressions", top_level_with.expressions) 713 else: 714 top_level_with.set( 715 "expressions", top_level_with.expressions + inner_with.expressions 716 ) 717 718 return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
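Using the query from the docstring above, a sketch of the effect:

    import sqlglot
    from sqlglot.transforms import move_ctes_to_top_level

    sql = "SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq"
    # The nested WITH clause is hoisted onto the outer query, giving roughly:
    # WITH t(c) AS (SELECT 1) SELECT * FROM (SELECT * FROM t) AS subq
    print(move_ctes_to_top_level(sqlglot.parse_one(sql)).sql())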
721def ensure_bools(expression: exp.Expression) -> exp.Expression: 722 """Converts numeric values used in conditions into explicit boolean expressions.""" 723 from sqlglot.optimizer.canonicalize import ensure_bools 724 725 def _ensure_bool(node: exp.Expression) -> None: 726 if ( 727 node.is_number 728 or ( 729 not isinstance(node, exp.SubqueryPredicate) 730 and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES) 731 ) 732 or (isinstance(node, exp.Column) and not node.type) 733 ): 734 node.replace(node.neq(0)) 735 736 for node in expression.walk(): 737 ensure_bools(node, _ensure_bool) 738 739 return expression
Converts numeric values used in conditions into explicit boolean expressions.
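For example (a sketch; the table name is made up):

    import sqlglot
    from sqlglot.transforms import ensure_bools

    sql = "SELECT * FROM t WHERE 1"
    # The bare numeric condition is turned into an explicit comparison, roughly:
    # SELECT * FROM t WHERE 1 <> 0
    print(ensure_bools(sqlglot.parse_one(sql)).sql())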
760def ctas_with_tmp_tables_to_create_tmp_view( 761 expression: exp.Expression, 762 tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e, 763) -> exp.Expression: 764 assert isinstance(expression, exp.Create) 765 properties = expression.args.get("properties") 766 temporary = any( 767 isinstance(prop, exp.TemporaryProperty) 768 for prop in (properties.expressions if properties else []) 769 ) 770 771 # CTAS with temp tables map to CREATE TEMPORARY VIEW 772 if expression.kind == "TABLE" and temporary: 773 if expression.expression: 774 return exp.Create( 775 kind="TEMPORARY VIEW", 776 this=expression.this, 777 expression=expression.expression, 778 ) 779 return tmp_storage_provider(expression) 780 781 return expression
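ctas_with_tmp_tables_to_create_tmp_view maps a CTAS on a temporary table to a CREATE TEMPORARY VIEW statement. Since it asserts that it receives an exp.Create node, it is applied directly to the parsed statement rather than via Expression.transform. A sketch with hypothetical table names:

    import sqlglot
    from sqlglot.transforms import ctas_with_tmp_tables_to_create_tmp_view

    ddl = "CREATE TEMPORARY TABLE t AS SELECT * FROM src"
    # The temporary CTAS is mapped to a temporary view, giving roughly:
    # CREATE TEMPORARY VIEW t AS SELECT * FROM src
    print(ctas_with_tmp_tables_to_create_tmp_view(sqlglot.parse_one(ddl)).sql())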
784def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression: 785 """ 786 In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the 787 PARTITIONED BY value is an array of column names, they are transformed into a schema. 788 The corresponding columns are removed from the create statement. 789 """ 790 assert isinstance(expression, exp.Create) 791 has_schema = isinstance(expression.this, exp.Schema) 792 is_partitionable = expression.kind in {"TABLE", "VIEW"} 793 794 if has_schema and is_partitionable: 795 prop = expression.find(exp.PartitionedByProperty) 796 if prop and prop.this and not isinstance(prop.this, exp.Schema): 797 schema = expression.this 798 columns = {v.name.upper() for v in prop.this.expressions} 799 partitions = [col for col in schema.expressions if col.name.upper() in columns] 800 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 801 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 802 expression.set("this", schema) 803 804 return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
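A sketch of the rewrite; to obtain an input in the triggering shape (PARTITIONED BY holding plain column names rather than a column schema), the example first applies the sibling transform move_partitioned_by_to_schema_columns, defined just below, to Hive-style DDL:

    import sqlglot
    from sqlglot.transforms import (
        move_partitioned_by_to_schema_columns,
        move_schema_columns_to_partitioned_by,
    )

    hive_ddl = "CREATE TABLE t (a INT) PARTITIONED BY (b INT)"
    # Fold the partition column into the table schema (the Spark "datasource" shape) ...
    spark_style = move_partitioned_by_to_schema_columns(sqlglot.parse_one(hive_ddl, read="hive"))
    # ... then move it back out of the schema, restoring roughly:
    # CREATE TABLE t (a INT) PARTITIONED BY (b INT)
    print(move_schema_columns_to_partitioned_by(spark_style).sql(dialect="hive"))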
807def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression: 808 """ 809 Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE. 810 811 Currently, SQLGlot uses the DATASOURCE format for Spark 3. 812 """ 813 assert isinstance(expression, exp.Create) 814 prop = expression.find(exp.PartitionedByProperty) 815 if ( 816 prop 817 and prop.this 818 and isinstance(prop.this, exp.Schema) 819 and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions) 820 ): 821 prop_this = exp.Tuple( 822 expressions=[exp.to_identifier(e.this) for e in prop.this.expressions] 823 ) 824 schema = expression.this 825 for e in prop.this.expressions: 826 schema.append("expressions", e) 827 prop.set("this", prop_this) 828 829 return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
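A sketch converting Hive-style partition DDL to the Spark datasource shape, applied directly to the Create node (the exact output may vary by sqlglot version):

    import sqlglot
    from sqlglot.transforms import move_partitioned_by_to_schema_columns

    hive_ddl = "CREATE TABLE t (a INT) PARTITIONED BY (b INT)"
    # The partition column definition is appended to the table schema and PARTITIONED BY
    # is left with bare column names, giving roughly:
    # CREATE TABLE t (a INT, b INT) PARTITIONED BY (b)
    print(move_partitioned_by_to_schema_columns(sqlglot.parse_one(hive_ddl, read="hive")).sql(dialect="spark"))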
832def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression: 833 """Converts struct arguments to aliases, e.g. STRUCT(1 AS y).""" 834 if isinstance(expression, exp.Struct): 835 expression.set( 836 "expressions", 837 [ 838 exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e 839 for e in expression.expressions 840 ], 841 ) 842 843 return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
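A sketch that builds the triggering shape programmatically, since the parsed representation of struct key/value arguments varies by dialect:

    from sqlglot import exp
    from sqlglot.transforms import struct_kv_to_alias

    # A STRUCT whose argument is a key/value pair (PropertyEQ): y := 1
    struct = exp.Struct(
        expressions=[exp.PropertyEQ(this=exp.to_identifier("y"), expression=exp.Literal.number(1))]
    )
    # The key/value pair becomes an alias, generating roughly: STRUCT(1 AS y)
    print(struct_kv_to_alias(struct).sql())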
846def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: 847 """ 848 Remove join marks from an AST. This rule assumes that all marked columns are qualified. 849 If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first. 850 851 For example, 852 SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to 853 SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this 854 855 Args: 856 expression: The AST to remove join marks from. 857 858 Returns: 859 The AST with join marks removed. 860 """ 861 from sqlglot.optimizer.scope import traverse_scope 862 863 for scope in traverse_scope(expression): 864 query = scope.expression 865 866 where = query.args.get("where") 867 joins = query.args.get("joins") 868 869 if not where or not joins: 870 continue 871 872 query_from = query.args["from"] 873 874 # These keep track of the joins to be replaced 875 new_joins: t.Dict[str, exp.Join] = {} 876 old_joins = {join.alias_or_name: join for join in joins} 877 878 for column in scope.columns: 879 if not column.args.get("join_mark"): 880 continue 881 882 predicate = column.find_ancestor(exp.Predicate, exp.Select) 883 assert isinstance( 884 predicate, exp.Binary 885 ), "Columns can only be marked with (+) when involved in a binary operation" 886 887 predicate_parent = predicate.parent 888 join_predicate = predicate.pop() 889 890 left_columns = [ 891 c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark") 892 ] 893 right_columns = [ 894 c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark") 895 ] 896 897 assert not ( 898 left_columns and right_columns 899 ), "The (+) marker cannot appear in both sides of a binary predicate" 900 901 marked_column_tables = set() 902 for col in left_columns or right_columns: 903 table = col.table 904 assert table, f"Column {col} needs to be qualified with a table" 905 906 col.set("join_mark", False) 907 marked_column_tables.add(table) 908 909 assert ( 910 len(marked_column_tables) == 1 911 ), "Columns of only a single table can be marked with (+) in a given binary predicate" 912 913 # Add predicate if join already copied, or add join if it is new 914 join_this = old_joins.get(col.table, query_from).this 915 existing_join = new_joins.get(join_this.alias_or_name) 916 if existing_join: 917 existing_join.set("on", exp.and_(existing_join.args["on"], join_predicate)) 918 else: 919 new_joins[join_this.alias_or_name] = exp.Join( 920 this=join_this.copy(), on=join_predicate.copy(), kind="LEFT" 921 ) 922 923 # If the parent of the target predicate is a binary node, then it now has only one child 924 if isinstance(predicate_parent, exp.Binary): 925 if predicate_parent.left is None: 926 predicate_parent.replace(predicate_parent.right) 927 else: 928 predicate_parent.replace(predicate_parent.left) 929 930 if query_from.alias_or_name in new_joins: 931 only_old_joins = old_joins.keys() - new_joins.keys() 932 assert ( 933 len(only_old_joins) >= 1 934 ), "Cannot determine which table to use in the new FROM clause" 935 936 new_from_name = list(only_old_joins)[0] 937 query.set("from", exp.From(this=old_joins[new_from_name].this)) 938 939 if new_joins: 940 query.set("joins", list(new_joins.values())) 941 942 if not where.this: 943 where.pop() 944 945 return expression
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running sqlglot.optimizer.qualify first.
For example,
    SELECT * FROM a, b WHERE a.id = b.id(+)
is converted to
    SELECT * FROM a LEFT JOIN b ON a.id = b.id
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.
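A sketch using Oracle's (+) syntax from the example above (the columns are already qualified, as the rule requires):

    import sqlglot
    from sqlglot.transforms import eliminate_join_marks

    sql = "SELECT * FROM a, b WHERE a.id = b.id(+)"
    # The marked predicate is pulled out of the WHERE clause and becomes a LEFT JOIN, roughly:
    # SELECT * FROM a LEFT JOIN b ON a.id = b.id
    print(eliminate_join_marks(sqlglot.parse_one(sql, read="oracle")).sql())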
948def any_to_exists(expression: exp.Expression) -> exp.Expression: 949 """ 950 Transform ANY operator to Spark's EXISTS 951 952 For example, 953 - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) 954 - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5) 955 956 Both ANY and EXISTS accept queries but currently only array expressions are supported for this 957 transformation 958 """ 959 if isinstance(expression, exp.Select): 960 for any in expression.find_all(exp.Any): 961 this = any.this 962 if isinstance(this, exp.Query): 963 continue 964 965 binop = any.parent 966 if isinstance(binop, exp.Binary): 967 lambda_arg = exp.to_identifier("x") 968 any.replace(lambda_arg) 969 lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg]) 970 binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr)) 971 972 return expression
Transforms the ANY operator into Spark's EXISTS.
For example,
- Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
- Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
Both ANY and EXISTS accept queries, but currently only array expressions are supported for this transformation.
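A sketch based on the docstring example, generated for Spark (the lambda body mirrors the original comparison, so its exact form may differ slightly from the docstring):

    import sqlglot
    from sqlglot.transforms import any_to_exists

    sql = "SELECT * FROM tbl WHERE 5 > ANY(tbl.col)"
    # The ANY comparison over an array column becomes a higher-order EXISTS(col, x -> ...) call.
    print(any_to_exists(sqlglot.parse_one(sql, read="postgres")).sql(dialect="spark"))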