sqlglot.transforms
1from __future__ import annotations 2 3import typing as t 4 5from sqlglot import expressions as exp 6from sqlglot.helper import find_new_name, name_sequence 7 8if t.TYPE_CHECKING: 9 from sqlglot.generator import Generator 10 11 12def preprocess( 13 transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], 14) -> t.Callable[[Generator, exp.Expression], str]: 15 """ 16 Creates a new transform by chaining a sequence of transformations and converts the resulting 17 expression to SQL, using either the "_sql" method corresponding to the resulting expression, 18 or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). 19 20 Args: 21 transforms: sequence of transform functions. These will be called in order. 22 23 Returns: 24 Function that can be used as a generator transform. 25 """ 26 27 def _to_sql(self, expression: exp.Expression) -> str: 28 expression_type = type(expression) 29 30 expression = transforms[0](expression) 31 for transform in transforms[1:]: 32 expression = transform(expression) 33 34 _sql_handler = getattr(self, expression.key + "_sql", None) 35 if _sql_handler: 36 return _sql_handler(expression) 37 38 transforms_handler = self.TRANSFORMS.get(type(expression)) 39 if transforms_handler: 40 if expression_type is type(expression): 41 if isinstance(expression, exp.Func): 42 return self.function_fallback_sql(expression) 43 44 # Ensures we don't enter an infinite loop. This can happen when the original expression 45 # has the same type as the final expression and there's no _sql method available for it, 46 # because then it'd re-enter _to_sql. 47 raise ValueError( 48 f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." 
49 ) 50 51 return transforms_handler(self, expression) 52 53 raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") 54 55 return _to_sql 56 57 58def unalias_group(expression: exp.Expression) -> exp.Expression: 59 """ 60 Replace references to select aliases in GROUP BY clauses. 61 62 Example: 63 >>> import sqlglot 64 >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 65 'SELECT a AS b FROM x GROUP BY 1' 66 67 Args: 68 expression: the expression that will be transformed. 69 70 Returns: 71 The transformed expression. 72 """ 73 if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): 74 aliased_selects = { 75 e.alias: i 76 for i, e in enumerate(expression.parent.expressions, start=1) 77 if isinstance(e, exp.Alias) 78 } 79 80 for group_by in expression.expressions: 81 if ( 82 isinstance(group_by, exp.Column) 83 and not group_by.table 84 and group_by.name in aliased_selects 85 ): 86 group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) 87 88 return expression 89 90 91def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: 92 """ 93 Convert SELECT DISTINCT ON statements to a subquery with a window function. 94 95 This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. 96 97 Args: 98 expression: the expression that will be transformed. 99 100 Returns: 101 The transformed expression. 
102 """ 103 if ( 104 isinstance(expression, exp.Select) 105 and expression.args.get("distinct") 106 and expression.args["distinct"].args.get("on") 107 and isinstance(expression.args["distinct"].args["on"], exp.Tuple) 108 ): 109 distinct_cols = expression.args["distinct"].pop().args["on"].expressions 110 outer_selects = expression.selects 111 row_number = find_new_name(expression.named_selects, "_row_number") 112 window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) 113 order = expression.args.get("order") 114 115 if order: 116 window.set("order", order.pop()) 117 else: 118 window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols])) 119 120 window = exp.alias_(window, row_number) 121 expression.select(window, copy=False) 122 123 return ( 124 exp.select(*outer_selects, copy=False) 125 .from_(expression.subquery("_t", copy=False), copy=False) 126 .where(exp.column(row_number).eq(1), copy=False) 127 ) 128 129 return expression 130 131 132def eliminate_qualify(expression: exp.Expression) -> exp.Expression: 133 """ 134 Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. 135 136 The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: 137 https://docs.snowflake.com/en/sql-reference/constructs/qualify 138 139 Some dialects don't support window functions in the WHERE clause, so we need to include them as 140 projections in the subquery, in order to refer to them in the outer filter using aliases. Also, 141 if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, 142 otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a 143 newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the 144 corresponding expression to avoid creating invalid column references. 
145 """ 146 if isinstance(expression, exp.Select) and expression.args.get("qualify"): 147 taken = set(expression.named_selects) 148 for select in expression.selects: 149 if not select.alias_or_name: 150 alias = find_new_name(taken, "_c") 151 select.replace(exp.alias_(select, alias)) 152 taken.add(alias) 153 154 def _select_alias_or_name(select: exp.Expression) -> str | exp.Column: 155 alias_or_name = select.alias_or_name 156 identifier = select.args.get("alias") or select.this 157 if isinstance(identifier, exp.Identifier): 158 return exp.column(alias_or_name, quoted=identifier.args.get("quoted")) 159 return alias_or_name 160 161 outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects))) 162 qualify_filters = expression.args["qualify"].pop().this 163 expression_by_alias = { 164 select.alias: select.this 165 for select in expression.selects 166 if isinstance(select, exp.Alias) 167 } 168 169 select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column) 170 for select_candidate in qualify_filters.find_all(select_candidates): 171 if isinstance(select_candidate, exp.Window): 172 if expression_by_alias: 173 for column in select_candidate.find_all(exp.Column): 174 expr = expression_by_alias.get(column.name) 175 if expr: 176 column.replace(expr) 177 178 alias = find_new_name(expression.named_selects, "_w") 179 expression.select(exp.alias_(select_candidate, alias), copy=False) 180 column = exp.column(alias) 181 182 if isinstance(select_candidate.parent, exp.Qualify): 183 qualify_filters = column 184 else: 185 select_candidate.replace(column) 186 elif select_candidate.name not in expression.named_selects: 187 expression.select(select_candidate.copy(), copy=False) 188 189 return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where( 190 qualify_filters, copy=False 191 ) 192 193 return expression 194 195 196def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: 197 """ 198 
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in 199 other expressions. This transforms removes the precision from parameterized types in expressions. 200 """ 201 for node in expression.find_all(exp.DataType): 202 node.set( 203 "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)] 204 ) 205 206 return expression 207 208 209def unqualify_unnest(expression: exp.Expression) -> exp.Expression: 210 """Remove references to unnest table aliases, added by the optimizer's qualify_columns step.""" 211 from sqlglot.optimizer.scope import find_all_in_scope 212 213 if isinstance(expression, exp.Select): 214 unnest_aliases = { 215 unnest.alias 216 for unnest in find_all_in_scope(expression, exp.Unnest) 217 if isinstance(unnest.parent, (exp.From, exp.Join)) 218 } 219 if unnest_aliases: 220 for column in expression.find_all(exp.Column): 221 if column.table in unnest_aliases: 222 column.set("table", None) 223 elif column.db in unnest_aliases: 224 column.set("db", None) 225 226 return expression 227 228 229def unnest_to_explode(expression: exp.Expression) -> exp.Expression: 230 """Convert cross join unnest into lateral view explode.""" 231 if isinstance(expression, exp.Select): 232 from_ = expression.args.get("from") 233 234 if from_ and isinstance(from_.this, exp.Unnest): 235 unnest = from_.this 236 alias = unnest.args.get("alias") 237 udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode 238 this, *expressions = unnest.expressions 239 unnest.replace( 240 exp.Table( 241 this=udtf( 242 this=this, 243 expressions=expressions, 244 ), 245 alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None, 246 ) 247 ) 248 249 for join in expression.args.get("joins") or []: 250 unnest = join.this 251 252 if isinstance(unnest, exp.Unnest): 253 alias = unnest.args.get("alias") 254 udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode 255 256 
expression.args["joins"].remove(join) 257 258 for e, column in zip(unnest.expressions, alias.columns if alias else []): 259 expression.append( 260 "laterals", 261 exp.Lateral( 262 this=udtf(this=e), 263 view=True, 264 alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore 265 ), 266 ) 267 268 return expression 269 270 271def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: 272 """Convert explode/posexplode into unnest.""" 273 274 def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: 275 if isinstance(expression, exp.Select): 276 from sqlglot.optimizer.scope import Scope 277 278 taken_select_names = set(expression.named_selects) 279 taken_source_names = {name for name, _ in Scope(expression).references} 280 281 def new_name(names: t.Set[str], name: str) -> str: 282 name = find_new_name(names, name) 283 names.add(name) 284 return name 285 286 arrays: t.List[exp.Condition] = [] 287 series_alias = new_name(taken_select_names, "pos") 288 series = exp.alias_( 289 exp.Unnest( 290 expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))] 291 ), 292 new_name(taken_source_names, "_u"), 293 table=[series_alias], 294 ) 295 296 # we use list here because expression.selects is mutated inside the loop 297 for select in list(expression.selects): 298 explode = select.find(exp.Explode) 299 300 if explode: 301 pos_alias = "" 302 explode_alias = "" 303 304 if isinstance(select, exp.Alias): 305 explode_alias = select.args["alias"] 306 alias = select 307 elif isinstance(select, exp.Aliases): 308 pos_alias = select.aliases[0] 309 explode_alias = select.aliases[1] 310 alias = select.replace(exp.alias_(select.this, "", copy=False)) 311 else: 312 alias = select.replace(exp.alias_(select, "")) 313 explode = alias.find(exp.Explode) 314 assert explode 315 316 is_posexplode = isinstance(explode, exp.Posexplode) 317 explode_arg = explode.this 318 319 if isinstance(explode, exp.ExplodeOuter): 320 bracket = 
explode_arg[0] 321 bracket.set("safe", True) 322 bracket.set("offset", True) 323 explode_arg = exp.func( 324 "IF", 325 exp.func( 326 "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array()) 327 ).eq(0), 328 exp.array(bracket, copy=False), 329 explode_arg, 330 ) 331 332 # This ensures that we won't use [POS]EXPLODE's argument as a new selection 333 if isinstance(explode_arg, exp.Column): 334 taken_select_names.add(explode_arg.output_name) 335 336 unnest_source_alias = new_name(taken_source_names, "_u") 337 338 if not explode_alias: 339 explode_alias = new_name(taken_select_names, "col") 340 341 if is_posexplode: 342 pos_alias = new_name(taken_select_names, "pos") 343 344 if not pos_alias: 345 pos_alias = new_name(taken_select_names, "pos") 346 347 alias.set("alias", exp.to_identifier(explode_alias)) 348 349 series_table_alias = series.args["alias"].this 350 column = exp.If( 351 this=exp.column(series_alias, table=series_table_alias).eq( 352 exp.column(pos_alias, table=unnest_source_alias) 353 ), 354 true=exp.column(explode_alias, table=unnest_source_alias), 355 ) 356 357 explode.replace(column) 358 359 if is_posexplode: 360 expressions = expression.expressions 361 expressions.insert( 362 expressions.index(alias) + 1, 363 exp.If( 364 this=exp.column(series_alias, table=series_table_alias).eq( 365 exp.column(pos_alias, table=unnest_source_alias) 366 ), 367 true=exp.column(pos_alias, table=unnest_source_alias), 368 ).as_(pos_alias), 369 ) 370 expression.set("expressions", expressions) 371 372 if not arrays: 373 if expression.args.get("from"): 374 expression.join(series, copy=False, join_type="CROSS") 375 else: 376 expression.from_(series, copy=False) 377 378 size: exp.Condition = exp.ArraySize(this=explode_arg.copy()) 379 arrays.append(size) 380 381 # trino doesn't support left join unnest with on conditions 382 # if it did, this would be much simpler 383 expression.join( 384 exp.alias_( 385 exp.Unnest( 386 expressions=[explode_arg.copy()], 387 
offset=exp.to_identifier(pos_alias), 388 ), 389 unnest_source_alias, 390 table=[explode_alias], 391 ), 392 join_type="CROSS", 393 copy=False, 394 ) 395 396 if index_offset != 1: 397 size = size - 1 398 399 expression.where( 400 exp.column(series_alias, table=series_table_alias) 401 .eq(exp.column(pos_alias, table=unnest_source_alias)) 402 .or_( 403 (exp.column(series_alias, table=series_table_alias) > size).and_( 404 exp.column(pos_alias, table=unnest_source_alias).eq(size) 405 ) 406 ), 407 copy=False, 408 ) 409 410 if arrays: 411 end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:]) 412 413 if index_offset != 1: 414 end = end - (1 - index_offset) 415 series.expressions[0].set("end", end) 416 417 return expression 418 419 return _explode_to_unnest 420 421 422def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 423 """Transforms percentiles by adding a WITHIN GROUP clause to them.""" 424 if ( 425 isinstance(expression, exp.PERCENTILES) 426 and not isinstance(expression.parent, exp.WithinGroup) 427 and expression.expression 428 ): 429 column = expression.this.pop() 430 expression.set("this", expression.expression.pop()) 431 order = exp.Order(expressions=[exp.Ordered(this=column)]) 432 expression = exp.WithinGroup(this=expression, expression=order) 433 434 return expression 435 436 437def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 438 """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.""" 439 if ( 440 isinstance(expression, exp.WithinGroup) 441 and isinstance(expression.this, exp.PERCENTILES) 442 and isinstance(expression.expression, exp.Order) 443 ): 444 quantile = expression.this.this 445 input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this 446 return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile)) 447 448 return expression 449 450 451def add_recursive_cte_column_names(expression: exp.Expression) -> 
exp.Expression: 452 """Uses projection output names in recursive CTE definitions to define the CTEs' columns.""" 453 if isinstance(expression, exp.With) and expression.recursive: 454 next_name = name_sequence("_c_") 455 456 for cte in expression.expressions: 457 if not cte.args["alias"].columns: 458 query = cte.this 459 if isinstance(query, exp.SetOperation): 460 query = query.this 461 462 cte.args["alias"].set( 463 "columns", 464 [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], 465 ) 466 467 return expression 468 469 470def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: 471 """Replace 'epoch' in casts by the equivalent date literal.""" 472 if ( 473 isinstance(expression, (exp.Cast, exp.TryCast)) 474 and expression.name.lower() == "epoch" 475 and expression.to.this in exp.DataType.TEMPORAL_TYPES 476 ): 477 expression.this.replace(exp.Literal.string("1970-01-01 00:00:00")) 478 479 return expression 480 481 482def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: 483 """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.""" 484 if isinstance(expression, exp.Select): 485 for join in expression.args.get("joins") or []: 486 on = join.args.get("on") 487 if on and join.kind in ("SEMI", "ANTI"): 488 subquery = exp.select("1").from_(join.this).where(on) 489 exists = exp.Exists(this=subquery) 490 if join.kind == "ANTI": 491 exists = exists.not_(copy=False) 492 493 join.pop() 494 expression.where(exists, copy=False) 495 496 return expression 497 498 499def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: 500 """ 501 Converts a query with a FULL OUTER join to a union of identical queries that 502 use LEFT/RIGHT OUTER joins instead. This transformation currently only works 503 for queries that have a single FULL OUTER join. 
504 """ 505 if isinstance(expression, exp.Select): 506 full_outer_joins = [ 507 (index, join) 508 for index, join in enumerate(expression.args.get("joins") or []) 509 if join.side == "FULL" 510 ] 511 512 if len(full_outer_joins) == 1: 513 expression_copy = expression.copy() 514 expression.set("limit", None) 515 index, full_outer_join = full_outer_joins[0] 516 full_outer_join.set("side", "left") 517 expression_copy.args["joins"][index].set("side", "right") 518 expression_copy.args.pop("with", None) # remove CTEs from RIGHT side 519 520 return exp.union(expression, expression_copy, copy=False) 521 522 return expression 523 524 525def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression: 526 """ 527 Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be 528 defined at the top-level, so for example queries like: 529 530 SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq 531 532 are invalid in those dialects. This transformation can be used to ensure all CTEs are 533 moved to the top level so that the final SQL code is valid from a syntax standpoint. 534 535 TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). 
536 """ 537 top_level_with = expression.args.get("with") 538 for inner_with in expression.find_all(exp.With): 539 if inner_with.parent is expression: 540 continue 541 542 if not top_level_with: 543 top_level_with = inner_with.pop() 544 expression.set("with", top_level_with) 545 else: 546 if inner_with.recursive: 547 top_level_with.set("recursive", True) 548 549 parent_cte = inner_with.find_ancestor(exp.CTE) 550 inner_with.pop() 551 552 if parent_cte: 553 i = top_level_with.expressions.index(parent_cte) 554 top_level_with.expressions[i:i] = inner_with.expressions 555 top_level_with.set("expressions", top_level_with.expressions) 556 else: 557 top_level_with.set( 558 "expressions", top_level_with.expressions + inner_with.expressions 559 ) 560 561 return expression 562 563 564def ensure_bools(expression: exp.Expression) -> exp.Expression: 565 """Converts numeric values used in conditions into explicit boolean expressions.""" 566 from sqlglot.optimizer.canonicalize import ensure_bools 567 568 def _ensure_bool(node: exp.Expression) -> None: 569 if ( 570 node.is_number 571 or ( 572 not isinstance(node, exp.SubqueryPredicate) 573 and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES) 574 ) 575 or (isinstance(node, exp.Column) and not node.type) 576 ): 577 node.replace(node.neq(0)) 578 579 for node in expression.walk(): 580 ensure_bools(node, _ensure_bool) 581 582 return expression 583 584 585def unqualify_columns(expression: exp.Expression) -> exp.Expression: 586 for column in expression.find_all(exp.Column): 587 # We only wanna pop off the table, db, catalog args 588 for part in column.parts[:-1]: 589 part.pop() 590 591 return expression 592 593 594def remove_unique_constraints(expression: exp.Expression) -> exp.Expression: 595 assert isinstance(expression, exp.Create) 596 for constraint in expression.find_all(exp.UniqueColumnConstraint): 597 if constraint.parent: 598 constraint.parent.pop() 599 600 return expression 601 602 603def 
ctas_with_tmp_tables_to_create_tmp_view( 604 expression: exp.Expression, 605 tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e, 606) -> exp.Expression: 607 assert isinstance(expression, exp.Create) 608 properties = expression.args.get("properties") 609 temporary = any( 610 isinstance(prop, exp.TemporaryProperty) 611 for prop in (properties.expressions if properties else []) 612 ) 613 614 # CTAS with temp tables map to CREATE TEMPORARY VIEW 615 if expression.kind == "TABLE" and temporary: 616 if expression.expression: 617 return exp.Create( 618 kind="TEMPORARY VIEW", 619 this=expression.this, 620 expression=expression.expression, 621 ) 622 return tmp_storage_provider(expression) 623 624 return expression 625 626 627def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression: 628 """ 629 In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the 630 PARTITIONED BY value is an array of column names, they are transformed into a schema. 631 The corresponding columns are removed from the create statement. 
632 """ 633 assert isinstance(expression, exp.Create) 634 has_schema = isinstance(expression.this, exp.Schema) 635 is_partitionable = expression.kind in {"TABLE", "VIEW"} 636 637 if has_schema and is_partitionable: 638 prop = expression.find(exp.PartitionedByProperty) 639 if prop and prop.this and not isinstance(prop.this, exp.Schema): 640 schema = expression.this 641 columns = {v.name.upper() for v in prop.this.expressions} 642 partitions = [col for col in schema.expressions if col.name.upper() in columns] 643 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 644 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 645 expression.set("this", schema) 646 647 return expression 648 649 650def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression: 651 """ 652 Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE. 653 654 Currently, SQLGlot uses the DATASOURCE format for Spark 3. 655 """ 656 assert isinstance(expression, exp.Create) 657 prop = expression.find(exp.PartitionedByProperty) 658 if ( 659 prop 660 and prop.this 661 and isinstance(prop.this, exp.Schema) 662 and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions) 663 ): 664 prop_this = exp.Tuple( 665 expressions=[exp.to_identifier(e.this) for e in prop.this.expressions] 666 ) 667 schema = expression.this 668 for e in prop.this.expressions: 669 schema.append("expressions", e) 670 prop.set("this", prop_this) 671 672 return expression 673 674 675def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression: 676 """Converts struct arguments to aliases, e.g. 
STRUCT(1 AS y).""" 677 if isinstance(expression, exp.Struct): 678 expression.set( 679 "expressions", 680 [ 681 exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e 682 for e in expression.expressions 683 ], 684 ) 685 686 return expression 687 688 689def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: 690 """ 691 Remove join marks from an AST. This rule assumes that all marked columns are qualified. 692 If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first. 693 694 For example, 695 SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to 696 SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this 697 698 Args: 699 expression: The AST to remove join marks from. 700 701 Returns: 702 The AST with join marks removed. 703 """ 704 from sqlglot.optimizer.scope import traverse_scope 705 706 for scope in traverse_scope(expression): 707 query = scope.expression 708 709 where = query.args.get("where") 710 joins = query.args.get("joins") 711 712 if not where or not joins: 713 continue 714 715 query_from = query.args["from"] 716 717 # These keep track of the joins to be replaced 718 new_joins: t.Dict[str, exp.Join] = {} 719 old_joins = {join.alias_or_name: join for join in joins} 720 721 for column in scope.columns: 722 if not column.args.get("join_mark"): 723 continue 724 725 predicate = column.find_ancestor(exp.Predicate, exp.Select) 726 assert isinstance( 727 predicate, exp.Binary 728 ), "Columns can only be marked with (+) when involved in a binary operation" 729 730 predicate_parent = predicate.parent 731 join_predicate = predicate.pop() 732 733 left_columns = [ 734 c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark") 735 ] 736 right_columns = [ 737 c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark") 738 ] 739 740 assert not ( 741 left_columns and right_columns 742 ), "The (+) marker cannot appear in both sides of a binary predicate" 743 744 
marked_column_tables = set() 745 for col in left_columns or right_columns: 746 table = col.table 747 assert table, f"Column {col} needs to be qualified with a table" 748 749 col.set("join_mark", False) 750 marked_column_tables.add(table) 751 752 assert ( 753 len(marked_column_tables) == 1 754 ), "Columns of only a single table can be marked with (+) in a given binary predicate" 755 756 join_this = old_joins.get(col.table, query_from).this 757 new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT") 758 759 # Upsert new_join into new_joins dictionary 760 new_join_alias_or_name = new_join.alias_or_name 761 existing_join = new_joins.get(new_join_alias_or_name) 762 if existing_join: 763 existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"])) 764 else: 765 new_joins[new_join_alias_or_name] = new_join 766 767 # If the parent of the target predicate is a binary node, then it now has only one child 768 if isinstance(predicate_parent, exp.Binary): 769 if predicate_parent.left is None: 770 predicate_parent.replace(predicate_parent.right) 771 else: 772 predicate_parent.replace(predicate_parent.left) 773 774 if query_from.alias_or_name in new_joins: 775 only_old_joins = old_joins.keys() - new_joins.keys() 776 assert ( 777 len(only_old_joins) >= 1 778 ), "Cannot determine which table to use in the new FROM clause" 779 780 new_from_name = list(only_old_joins)[0] 781 query.set("from", exp.From(this=old_joins[new_from_name].this)) 782 783 query.set("joins", list(new_joins.values())) 784 785 if not where.this: 786 where.pop() 787 788 return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    The `Generator.TRANSFORMS` function is only used when the chain changed the expression's
    type. If the type is unchanged, functions fall back to `function_fallback_sql` and any other
    expression raises, because calling the handler again would re-enter this same transform.

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed and leaves the expression untouched.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: (from the returned function) if the resulting expression has no way of
            being converted to SQL.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # A plain loop (instead of special-casing transforms[0]) also works for an empty list.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to projection aliases for the projections' ordinal positions.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # 1-based position of every aliased projection, keyed by alias
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for term in expression.expressions:
        # Qualified columns and non-column terms can't refer to a projection alias
        if not isinstance(term, exp.Column) or term.table:
            continue

        if term.name in position_by_alias:
            ordinal = position_by_alias.get(term.name)
            term.replace(exp.Literal.number(ordinal))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as a subquery that is filtered on a ROW_NUMBER() window.

    Useful for dialects that support window functions but not SELECT DISTINCT ON.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    distinct.pop()
    partition_cols = on.expressions

    # Taken before the window projection is added, so the outer query doesn't select it
    projections = expression.selects
    rn_alias = find_new_name(expression.named_selects, "_row_number")
    ranking = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    # Reuse the query's ORDER BY for the window, or default to the DISTINCT ON columns
    ordering = expression.args.get("order")
    if ordering:
        ordering = ordering.pop()
    else:
        ordering = exp.Order(expressions=[col.copy() for col in partition_cols])
    ranking.set("order", ordering)

    expression.select(exp.alias_(ranking, rn_alias), copy=False)

    outer = exp.select(*projections, copy=False)
    outer = outer.from_(expression.subquery("_t", copy=False), copy=False)
    return outer.where(exp.column(rn_alias).eq(1), copy=False)
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT wrapping the input (as subquery "_t") when the input is a SELECT
        with a QUALIFY clause, otherwise the input expression unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c" alias so the outer query can refer to it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                # Carry over the identifier's "quoted" flag to the outer column reference
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Built before any helper projections are added below, so the outer query
        # only selects the original projections
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # Plain columns are not considered when the query is a star projection
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Replace alias references inside the window by the aliased expressions
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the subquery under a fresh "_w" alias
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # A window directly under QUALIFY is the whole filter, so the filter
                # itself becomes the new column reference
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # The filter needs a column that the subquery doesn't project yet
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip `DataTypeParam` arguments from every data type found in the expression.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions, so this transform removes it from the types it finds.

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The same expression, with the type parameters removed.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = []
        for type_arg in data_type.expressions:
            if not isinstance(type_arg, exp.DataTypeParam):
                kept_args.append(type_arg)

        data_type.set("expressions", kept_args)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Args:
        expression: the expression that will be transformed. Only `exp.Select` nodes are
            processed; they are mutated in place.

    Returns:
        The same expression, where columns no longer use an UNNEST alias as their table
        (or, failing that, as their db) qualifier.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of the UNNESTs that act as sources (direct children of FROM / JOIN) in this scope
    source_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            source_aliases.add(unnest.alias)

    if not source_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in source_aliases:
            column.set("table", None)
        elif column.db in source_aliases:
            column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Two shapes are rewritten on `exp.Select` nodes:

    - `FROM UNNEST(...)` becomes a table wrapping `EXPLODE(...)` (or `POSEXPLODE(...)` when the
      unnest has an "offset" arg), keeping the table alias and its columns.
    - Every `JOIN UNNEST(...)` is removed from the joins and one `LATERAL VIEW` per
      (unnested expression, alias column) pair is appended to the select's "laterals".

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: UNNEST joins are removed from `joins` inside the loop, and
        # removing from the list being iterated would skip the join that follows each removal
        # (e.g. the second of two consecutive UNNEST joins would be left untouched).
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                # Without an alias there are no columns to pair with, so no lateral is emitted
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: start value of the generated position series; it also controls how the
            array sizes are adjusted (no adjustment is made when it equals 1). Presumably this
            is the target dialect's array index base -- confirm against the callers.

    Returns:
        A transform that rewrites `exp.Select` nodes in place and returns them; other
        expressions are returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names that are already in use by the projections / by the referenced sources
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a name derived from `name` that is not in `names` and reserves it there
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE(...) expression per exploded array, used to bound the series below
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(<index_offset>, ...)) AS _u(pos); the series' "end" is only
            # set at the bottom of this function, once all the array sizes are known
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # The first alias names the position, the second one names the value
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # exp.alias_ is called without copy=False here, so presumably the
                        # projection was copied and the Explode node has to be looked up again
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Builds IF(ARRAY_SIZE(COALESCE(arr, ARRAY())) = 0, ARRAY(arr[0]), arr),
                        # where the arr[0] bracket is flagged with safe=True and offset=True
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call becomes IF(<series>.pos = <unnest>.pos, <unnest>.col)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Add the position as a separate projection, right after the value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode of this SELECT: attach the series source exactly once
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` is rebound to the value compared against positions in the filter
                    # below; `arrays` keeps the unadjusted ARRAY_SIZE node
                    if index_offset != 1:
                        size = size - 1

                    # Keep the rows where the series position matches the unnest position, or
                    # where the series has run past this array and the unnest position equals
                    # `size`
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the largest of the exploded arrays
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    The percentile's first argument becomes the ORDER BY key of the WITHIN GROUP clause and
    its second argument takes the first one's place.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A `WithinGroup` node wrapping the percentile, or the input expression if it isn't a
        two-argument percentile or is already inside a WITHIN GROUP.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    order_key = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=order_key)])
    return exp.WithinGroup(this=expression, expression=ordering)
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The `ApproxQuantile` node that replaced the `WithinGroup` in the tree, built from the
        ordered value and the percentile's argument; otherwise the input expression.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)

    return expression.replace(approx_quantile)
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only recursive `exp.With` nodes are processed. CTEs whose alias already lists columns are
    left alone; for set operations the names are taken from the left-hand query. Projections
    without an output name get a generated `_c_*` name.

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = []
        for projection in query.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        table_alias.set("columns", column_names)

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The same expression; if it is a CAST / TRY_CAST of 'epoch' (case-insensitive) to a
        temporal type, its operand is now the string literal '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI / ANTI join that has an ON condition is removed from the query and replaced by a
    `[NOT] EXISTS (SELECT 1 FROM <joined source> WHERE <on condition>)` predicate, which is
    added to the query's WHERE clause.

    Args:
        expression: the expression that will be transformed. Only `exp.Select` nodes are
            processed; they are mutated in place.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: `join.pop()` detaches the join from the query,
        # and if that removes it from the list being iterated, the join that follows would be
        # skipped (e.g. the second of two consecutive SEMI joins).
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    Args:
        expression: the expression that will be transformed. The original SELECT is mutated
            and becomes the left operand of the union.

    Returns:
        The union of the LEFT-join query and a RIGHT-join copy of it, or the input expression
        when it is not a SELECT with exactly one FULL join.
    """
    if not isinstance(expression, exp.Select):
        return expression

    joins = expression.args.get("joins") or []
    full_joins = [(position, join) for position, join in enumerate(joins) if join.side == "FULL"]

    if len(full_joins) != 1:
        return expression

    position, left_join = full_joins[0]

    # The copy is taken before the limit is cleared, so only the copy keeps the LIMIT
    right_query = expression.copy()
    expression.set("limit", None)

    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The same expression, with every nested WITH clause merged into its own "with" arg.
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # The WITH clause that already belongs to the top-level expression stays where it is
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # There's no top-level WITH yet, so the first nested one is promoted as a whole
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # A recursive nested WITH makes the merged top-level WITH recursive as well
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # Must be looked up before the pop below, which detaches inner_with from the tree
            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # The nested WITH was defined inside a CTE: its CTEs are spliced in right
                # before that CTE (presumably so that they're defined before they're used)
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                # NOTE(review): the list was already mutated in place; this set() presumably
                # re-registers the spliced CTEs under top_level_with -- confirm against
                # Expression.set
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's `canonicalize.ensure_bools` helper,
    together with a callback that replaces a node `x` by `x <> 0` when `x` is a number, has an
    unknown or numeric type (subquery predicates excluded), or is an untyped column.

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The transformed expression.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _needs_comparison(candidate: exp.Expression) -> bool:
        if candidate.is_number:
            return True

        if not isinstance(candidate, exp.SubqueryPredicate) and candidate.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            return True

        return isinstance(candidate, exp.Column) and not candidate.type

    def _ensure_bool(node: exp.Expression) -> None:
        if _needs_comparison(node):
            node.replace(node.neq(0))

    for visited in expression.walk():
        _canonicalize_ensure_bools(visited, _ensure_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TABLE statements that have a temporary property to CREATE TEMPORARY VIEW.

    Args:
        expression: the `exp.Create` expression that will be transformed.
        tmp_storage_provider: called with the expression when it creates a temporary table
            that has no defining query; its result is returned. Defaults to the identity.

    Returns:
        A new `CREATE TEMPORARY VIEW` for temporary CTAS statements, the result of
        `tmp_storage_provider` for temporary tables without a query, and the input expression
        in every other case.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []

    is_temporary = False
    for prop in property_nodes:
        if isinstance(prop, exp.TemporaryProperty):
            is_temporary = True
            break

    if not (expression.kind == "TABLE" and is_temporary):
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if not expression.expression:
        return tmp_storage_provider(expression)

    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Args:
        expression: the `exp.Create` expression that will be transformed (mutated in place).
            Only TABLE / VIEW statements that define a schema are processed.

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    # Column names are matched case-insensitively
    partition_names = {value.name.upper() for value in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    partition_schema = exp.Schema(expressions=partition_columns)
    prop.replace(exp.PartitionedByProperty(this=partition_schema))
    expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY property holds a schema made up exclusively of typed column
    definitions, those definitions are appended to the statement's schema and the property
    is reduced to a tuple of the column names.

    Args:
        expression: the `exp.Create` expression that will be transformed (mutated in place).

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    for column_def in partition_defs:
        if not (isinstance(column_def, exp.ColumnDef) and column_def.kind):
            return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(column_def.this) for column_def in partition_defs]
    )

    schema = expression.this
    for column_def in partition_defs:
        schema.append("expressions", column_def)

    prop.set("this", partition_names)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Args:
        expression: the expression that will be transformed. Only `exp.Struct` nodes are
            processed; they are mutated in place.

    Returns:
        The same expression, where every `PropertyEQ` argument of the struct has been
        replaced by an alias of its value; other arguments are kept as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    struct_fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            field = exp.alias_(field.expression, field.this)
        struct_fields.append(field)

    expression.set("expressions", struct_fields)

    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Only scopes whose query has both a WHERE clause and joins are rewritten.

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.

    Raises:
        AssertionError: if a marked column is not part of a binary predicate, if both sides of
            a predicate contain marked columns, if a marked column is not qualified with a
            table, if the marked columns of a predicate belong to more than one table, or if
            no source is left to be used in the new FROM clause.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            # Closest enclosing predicate; including exp.Select in the search means that
            # the lookup stops at the enclosing query at the latest
            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Remember the parent first: the predicate is detached from the tree and reused
            # as the ON condition of the new join
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the tables that the marked columns belong to
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column of the loop above; by the assertion, all of the
            # marked columns share its table. The FROM source is used when that table isn't
            # one of the joined sources.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # The FROM source itself became the target of a LEFT JOIN, so one of the previously
        # joined sources that isn't a join target has to take its place
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            # NOTE(review): `only_old_joins` is a set, so with more than one candidate the
            # chosen source depends on the set's iteration order -- confirm this is acceptable
            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # The query's joins are replaced by the newly built LEFT joins
        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause if no condition is left in it
        if not where.this:
            where.pop()

    return expression
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running sqlglot.optimizer.qualify
first.
For example, SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.