sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.errors import UnsupportedError
from sqlglot.helper import find_new_name, name_sequence


if t.TYPE_CHECKING:
    from sqlglot.generator import Generator
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
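As a minimal sketch of how this is typically wired up (the custom generator below is illustrative, not part of sqlglot), a chained transform can be registered against an expression type in a generator's TRANSFORMS mapping:

from sqlglot import exp, generator
from sqlglot.transforms import preprocess, eliminate_distinct_on, eliminate_qualify

class MyGenerator(generator.Generator):
    # Run both transforms on every SELECT before generating SQL for it.
    TRANSFORMS = {
        **generator.Generator.TRANSFORMS,
        exp.Select: preprocess([eliminate_distinct_on, eliminate_qualify]),
    }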
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression
Unnests GENERATE_SERIES or SEQUENCE table references.
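A rough usage sketch; this assumes the source dialect parses the table-valued call into a Table node wrapping a GenerateSeries expression, and the exact output depends on the target dialect and sqlglot version:

import sqlglot
from sqlglot.transforms import unnest_generate_series

# PostgreSQL-style table function in the FROM clause.
tree = sqlglot.parse_one("SELECT * FROM GENERATE_SERIES(1, 5) AS s", read="postgres")
# Apply the transform to every node; the table reference is rewritten into an UNNEST.
print(tree.transform(unnest_generate_series).sql(dialect="presto"))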
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
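A minimal usage sketch, assuming a Postgres-style input query; the subquery and window aliases (_t, _row_number) are chosen at runtime:

import sqlglot
from sqlglot.transforms import eliminate_distinct_on

sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, c DESC"
tree = sqlglot.parse_one(sql, read="postgres")
# DISTINCT ON is replaced by a ROW_NUMBER() window in a subquery, filtered to row 1.
print(tree.transform(eliminate_distinct_on).sql())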
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
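A minimal sketch, assuming a Snowflake-style query with QUALIFY; the window function is pushed into a subquery projection and filtered in the outer WHERE:

import sqlglot
from sqlglot.transforms import eliminate_qualify

sql = "SELECT a, b FROM t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) = 1"
tree = sqlglot.parse_one(sql, read="snowflake")
print(tree.transform(eliminate_qualify).sql())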
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
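For instance (a small sketch):

import sqlglot
from sqlglot.transforms import remove_precision_parameterized_types

tree = sqlglot.parse_one("SELECT CAST(x AS DECIMAL(10, 2)) FROM t")
# The (10, 2) parameters are dropped from the DataType node, leaving CAST(x AS DECIMAL).
print(tree.transform(remove_precision_parameterized_types).sql())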
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        for join in expression.args.get("joins") or []:
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                expression.args["joins"].remove(join)

                alias_cols = alias.columns if alias else []
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols if unnest_using_arrays_zip else [column],  # type: ignore
                            ),
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
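A rough sketch of transpiling a Presto-style CROSS JOIN UNNEST into Hive's LATERAL VIEW EXPLODE (dialect generators typically apply this via preprocess; the exact output shape depends on the sqlglot version):

import sqlglot
from sqlglot.transforms import unnest_to_explode

sql = "SELECT x.a FROM t CROSS JOIN UNNEST(t.arr) AS x(a)"
tree = sqlglot.parse_one(sql, read="presto")
# Called directly on the SELECT root; joins over UNNEST become LATERAL VIEW EXPLODE.
print(unnest_to_explode(tree).sql(dialect="hive"))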
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest."""

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
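Note that this is a transform factory: calling explode_to_unnest(index_offset) returns the actual transform, where index_offset reflects the target dialect's array indexing base. A rough sketch converting a Spark EXPLODE into a Presto-style UNNEST:

import sqlglot
from sqlglot.transforms import explode_to_unnest

tree = sqlglot.parse_one("SELECT EXPLODE(arr) FROM t", read="spark")
# index_offset=1 assumes a 1-based target such as Presto/Trino.
print(tree.transform(explode_to_unnest(1)).sql(dialect="presto"))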
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
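For example (a sketch; exact output formatting may differ by version), a recursive CTE without an explicit column list gets one derived from its projections:

import sqlglot
from sqlglot.transforms import add_recursive_cte_column_names

sql = "WITH RECURSIVE t AS (SELECT 1 AS n UNION ALL SELECT n + 1 FROM t WHERE n < 5) SELECT n FROM t"
tree = sqlglot.parse_one(sql)
# The transform targets exp.With nodes, so applying it via transform() reaches them.
print(tree.transform(add_recursive_cte_column_names).sql())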
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts by the equivalent date literal.
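A small sketch:

import sqlglot
from sqlglot.transforms import epoch_cast_to_ts

tree = sqlglot.parse_one("SELECT CAST('epoch' AS TIMESTAMP) FROM t")
# 'epoch' is replaced with the literal '1970-01-01 00:00:00' inside the cast.
print(tree.transform(epoch_cast_to_ts).sql())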
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
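A minimal sketch, assuming a Spark-style LEFT SEMI JOIN on input:

import sqlglot
from sqlglot.transforms import eliminate_semi_and_anti_joins

sql = "SELECT t1.a FROM t1 LEFT SEMI JOIN t2 ON t1.id = t2.id"
tree = sqlglot.parse_one(sql, read="spark")
# The SEMI join is dropped and replaced by a WHERE EXISTS(...) predicate.
print(tree.transform(eliminate_semi_and_anti_joins).sql())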
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
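A rough sketch (the result is a union of a LEFT-join copy and a RIGHT-join copy of the query):

import sqlglot
from sqlglot.transforms import eliminate_full_outer_join

sql = "SELECT * FROM a FULL OUTER JOIN b ON a.id = b.id"
tree = sqlglot.parse_one(sql)
print(tree.transform(eliminate_full_outer_join).sql())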
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
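A sketch using the query from the docstring; this transform inspects the whole tree itself, so it is applied to the root expression directly:

import sqlglot
from sqlglot.transforms import move_ctes_to_top_level

sql = "SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq"
tree = sqlglot.parse_one(sql)
# The inner WITH is hoisted so the generated SQL only has a top-level WITH clause.
print(move_ctes_to_top_level(tree).sql())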
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
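A small sketch; the transform walks the tree itself, so it is called on the root:

import sqlglot
from sqlglot.transforms import ensure_bools

tree = sqlglot.parse_one("SELECT a FROM t WHERE 1")
# The numeric condition is rewritten to an explicit comparison, e.g. WHERE 1 <> 0.
print(ensure_bools(tree).sql())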
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
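A rough sketch; it assumes a Spark-style CREATE TABLE whose PARTITIONED BY clause lists bare column names, and because of the assert it is meant to be called on a Create node directly rather than via transform():

import sqlglot
from sqlglot.transforms import move_schema_columns_to_partitioned_by

sql = "CREATE TABLE t (a INT, b INT) PARTITIONED BY (b)"
tree = sqlglot.parse_one(sql, read="spark")
# Column b (with its type) moves out of the schema and into PARTITIONED BY.
print(move_schema_columns_to_partitioned_by(tree).sql(dialect="hive"))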
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        if not where.this:
            where.pop()

    return expression
Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.
For example,
SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.
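A sketch mirroring the docstring example; like the other scope-based transforms, it is applied to the root expression:

import sqlglot
from sqlglot.transforms import eliminate_join_marks

sql = "SELECT * FROM a, b WHERE a.id = b.id(+)"
tree = sqlglot.parse_one(sql, read="oracle")
# The Oracle (+) mark becomes an explicit LEFT JOIN ... ON condition.
print(eliminate_join_marks(tree).sql(dialect="oracle"))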