sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # 1-based projection position for every aliased projection.
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects[group_by.name]))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The original projections stay in the inner query (aliased when they have no usable name) and
    the outer query references them by name, so projections such as `t.a` or `a + 1` remain
    resolvable once they are wrapped in the `_t` subquery.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")

        # Names a newly introduced projection alias must not clash with.
        taken_names = {row_number}
        taken_names.update(s.alias for s in expression.selects if isinstance(s, exp.Alias))

        outer_selects: t.List[exp.Expression] = []
        # Iterate over a snapshot because projections are replaced (aliased) inside the loop.
        for select in list(expression.selects):
            if select.is_star:
                # A star cannot be referenced by name, so the outer query selects `*` as well.
                outer_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.alias)
            # Reference the projection by its output name instead of re-using the inner node.
            outer_selects.append(exp.column(select.args["alias"].copy()))

        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            # Copy so the projection node is not re-parented into the window.
                            column.replace(expr.copy())

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join whose target is an UNNEST is removed. For each unnested expression, it is
    replaced by a LATERAL VIEW EXPLODE, or POSEXPLODE when the UNNEST has an offset. Each
    expression is paired positionally with a column of the UNNEST alias.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the list inside the loop, and removing
        # from the list being iterated would silently skip the join that follows.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): without a column alias zip() yields nothing, so the UNNEST is
                # dropped without a replacement lateral view.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the first value of the generated position series; it also shifts how
            array sizes are turned into last-index values.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.Union):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot because matching joins are popped inside the loop.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression


def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for node in expression.find_all(exp.With):
        if node.parent is expression:
            continue

        inner_with = node.pop()
        if not top_level_with:
            top_level_with = inner_with
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # Inner CTEs are placed first so they are defined before any CTE that uses them.
            top_level_with.set("expressions", inner_with.expressions + top_level_with.expressions)

    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strip every qualifier (table, db, catalog) from all columns, keeping only the column name."""
    for column in expression.find_all(exp.Column):
        # We only want to pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove UNIQUE column constraints from a CREATE statement.

    The parent of each UniqueColumnConstraint is popped. Presumably that parent is the column
    constraint wrapper (TODO confirm).
    """
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Turn a CREATE TEMPORARY TABLE into a CREATE TEMPORARY VIEW when it has a query.

    A temporary table without a query is handed to `tmp_storage_provider` (identity by
    default). Any other CREATE statement is returned unchanged.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Column names are matched case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3. When every PARTITIONED BY entry
    is a typed column definition, the definitions are appended to the table schema and the
    property is reduced to a tuple of column names.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression


def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence leaves the expression unchanged.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY keys that name a projection alias into that projection's position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the node to transform; only a GROUP node directly under a SELECT is touched.

    Returns:
        The same node, possibly with some of its keys replaced by numeric positions.
    """
    if not isinstance(expression, exp.Group) or not isinstance(expression.parent, exp.Select):
        return expression

    # Alias name -> 1-based position of the aliased projection (later duplicates win).
    position_by_alias = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for key in expression.expressions:
        # Only bare (unqualified) column references can point at a projection alias.
        if not isinstance(key, exp.Column) or key.table:
            continue
        if key.name in position_by_alias:
            key.replace(exp.Literal.number(position_by_alias[key.name]))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The original projections stay in the inner query (aliased when they have no usable name) and
    the outer query references them by name, so projections such as `t.a` or `a + 1` remain
    resolvable once they are wrapped in the `_t` subquery.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")

        # Names a newly introduced projection alias must not clash with.
        taken_names = {row_number}
        taken_names.update(s.alias for s in expression.selects if isinstance(s, exp.Alias))

        outer_selects: t.List[exp.Expression] = []
        # Iterate over a snapshot because projections are replaced (aliased) inside the loop.
        for select in list(expression.selects):
            if select.is_star:
                # A star cannot be referenced by name, so the outer query selects `*` as well.
                outer_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.alias)
            # Reference the projection by its output name instead of re-using the inner node.
            outer_selects.append(exp.column(select.args["alias"].copy()))

        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so the outer query can reference it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            # Copy so the projection node is not re-parented into the window.
                            column.replace(expr.copy())

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Drop the type parameters (such as precision) from every data type in the expression.

    Some dialects only accept the precision of parameterized types in DDL, not in other
    expressions. This transform removes those parameters from every DataType node.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for arg in data_type.expressions:
            if not isinstance(arg, exp.DataTypeParam):
                kept.append(arg)
        data_type.set("expressions", kept)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Strip UNNEST table aliases from column qualifiers.

    These qualifiers are the ones added by the optimizer's qualify_columns step. Only UNNESTs
    sitting directly in a FROM or JOIN of this SELECT's scope are considered.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    aliases = set()
    for node in find_all_in_scope(expression, exp.Unnest):
        if isinstance(node.parent, (exp.From, exp.Join)):
            aliases.add(node.alias)

    if not aliases:
        return expression

    for col in expression.find_all(exp.Column):
        # The table qualifier is checked first; the db qualifier only when the table didn't match.
        if col.table in aliases:
            col.set("table", None)
        elif col.db in aliases:
            col.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join whose target is an UNNEST is removed. For each unnested expression, it is
    replaced by a LATERAL VIEW EXPLODE, or POSEXPLODE when the UNNEST has an offset. Each
    expression is paired positionally with a column of the UNNEST alias.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the list inside the loop, and removing
        # from the list being iterated would silently skip the join that follows.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): without a column alias zip() yields nothing, so the UNNEST is
                # dropped without a replacement lateral view.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    The returned transform rewrites every projection of a SELECT that contains an
    [POS]EXPLODE(arr) in three parts:
      * a position series, UNNEST(GENERATE_SERIES(index_offset, GREATEST(sizes...))), is
        cross-joined once;
      * each exploded array is cross-joined as UNNEST(arr) WITH OFFSET;
      * the projection becomes IF(series_pos = offset, value), and a WHERE filter keeps rows
        whose series position matches the offset, or rows past the array's end paired with
        its last offset.

    Args:
        index_offset: start value of the generated series; it also adjusts how array sizes
            are converted into last-index values.

    Returns:
        A transform that takes an expression and returns the (possibly rewritten) expression.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources; new aliases must avoid them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick an unused name and reserve it so later calls don't reuse it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded array; their GREATEST bounds the series.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Multi-alias form: first alias is the position, second is the value.
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Wrap the unaliased projection so an alias can be set below; the
                        # wrapper holds a copy, so the Explode node is looked up again.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # For EXPLODE_OUTER, an empty or NULL array is swapped for a
                        # one-element array built from a safe [0]-offset access of the
                        # original, so the row is kept rather than dropped.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    # IF(series.pos = unnest.pos, unnest.value): no ELSE, so the value is
                    # NULL on rows where the positions differ.
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position: add it right after the value.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode seen: attach the position series as a source.
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Convert the array size into the last offset value for this index base.
                    if index_offset != 1:
                        size = size - 1

                    # Keep matching positions, plus series positions beyond this array's
                    # end paired with its last offset (shorter arrays then yield NULLs).
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must run up to the largest array's last index.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    A percentile node that is not already wrapped in a WithinGroup and that carries an
    `expression` argument is rewritten:
    - its `expression` argument becomes its `this` argument;
    - its former `this` argument moves into a `WITHIN GROUP (ORDER BY ...)` clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The new WithinGroup node, or the untouched expression when no rewrite applies.
    """
    if not (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        return expression

    # Detach the ordered column first, then promote the second argument into its place.
    ordered_column = expression.this.pop()
    expression.set("this", expression.expression.pop())

    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_column)]),
    )
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    A `<percentile>(q) WITHIN GROUP (ORDER BY x)` node is replaced in the tree by an
    ApproxQuantile whose input is the ordered value `x` and whose quantile is `q`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacement ApproxQuantile node, or the untouched expression when no rewrite applies.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    percentile = expression.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx_quantile)
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Each CTE of a recursive WITH clause that does not already declare a column list receives
    one, taken from the projections of its query (the left-hand side when the query is a
    UNION). Projections without an alias or name get names generated from the "_c_" prefix.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same (mutated) expression.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this.this if isinstance(cte.this, exp.Union) else cte.this
        column_names = [projection.alias_or_name or next_name() for projection in query.selects]
        table_alias.set("columns", [exp.to_identifier(name) for name in column_names])

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    A CAST/TRY_CAST to a temporal type whose operand is named 'epoch' (compared
    case-insensitively) gets that operand replaced by the literal '1970-01-01 00:00:00'.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same (possibly mutated) expression.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression
    if expression.name.lower() != "epoch":
        return expression
    if expression.to.this not in exp.DataType.TEMPORAL_TYPES:
        return expression

    epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
    expression.this.replace(epoch_literal)
    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Rewrites are as follows:
    - `... SEMI JOIN y ON <cond>` is removed from the SELECT, which is then filtered with
      `EXISTS (SELECT 1 FROM y WHERE <cond>)`.
    - An ANTI join uses `NOT EXISTS (...)` instead.
    - SEMI/ANTI joins without an ON condition are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same (possibly mutated) expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` below detaches the join from this SELECT, which
        # can shrink the live "joins" list and would otherwise make the loop skip the join that
        # immediately follows a rewritten one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if not on or join.kind not in ("SEMI", "ANTI"):
                continue

            subquery = exp.select("1").from_(join.this).where(on)
            exists = exp.Exists(this=subquery)
            if join.kind == "ANTI":
                exists = exists.not_(copy=False)

            join.pop()
            expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    How clauses are distributed between the two queries:
    - LIMIT and ORDER BY are removed from the left-hand (LEFT JOIN) query and kept only on
      the right-hand (RIGHT JOIN) copy. That copy is rendered after the UNION, so the
      clauses apply to the combined result. Otherwise an ORDER BY would sit in front of the
      UNION, which dialects such as MySQL reject.
    - CTEs are kept only on the left-hand query.

    NOTE(review): `exp.union` is called without `distinct=False`, so the two halves are
    presumably combined with a distinct UNION. That also collapses rows which the original
    FULL OUTER JOIN would have returned more than once.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A Union of the two rewritten queries, or the untouched expression when it does not
        contain exactly one FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            # LIMIT / ORDER BY survive only on the right-hand copy (see docstring).
            expression.set("limit", None)
            expression.set("order", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    Every nested WITH clause is detached from its query:
    - If `expression` has no WITH of its own yet, the first detached one becomes it.
    - Otherwise the nested CTEs are prepended to the top-level list, and the top-level WITH
      is marked recursive whenever a moved WITH was recursive.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")

    for with_node in expression.find_all(exp.With):
        if with_node.parent is expression:
            # Already attached to the top-level query.
            continue

        inner_with = with_node.pop()

        if top_level_with:
            if inner_with.recursive:
                top_level_with.set("recursive", True)
            top_level_with.set("expressions", inner_with.expressions + top_level_with.expressions)
        else:
            top_level_with = inner_with
            expression.set("with", top_level_with)

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of `expression` is visited, and the condition operands that
    `sqlglot.optimizer.canonicalize.ensure_bools` hands to the callback are rewritten to
    `<operand> <> 0`. This applies when the operand is:
    - a number;
    - typed as UNKNOWN or a numeric type;
    - a column without a type.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same (mutated) expression.
    """
    # Aliased so the import does not shadow this function's own name inside its body.
    from sqlglot.optimizer.canonicalize import ensure_bools as _rewrite_condition_operands

    def _is_numeric_like(node: exp.Expression) -> bool:
        if node.is_number:
            return True
        if node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES):
            return True
        return isinstance(node, exp.Column) and not node.type

    def _to_bool(node: exp.Expression) -> None:
        if _is_numeric_like(node):
            node.replace(node.neq(0))

    for node in expression.walk():
        _rewrite_condition_operands(node, _to_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite CREATE statements for temporary tables.

    A `CREATE TABLE` that carries a TemporaryProperty is handled as follows:
    - With a query (CTAS), it becomes `CREATE TEMPORARY VIEW <this> AS <query>`. Only `this`
      and the query are carried over, so any other arguments of the original CREATE (e.g.
      its properties) are not kept.
    - Without a query, it is passed to `tmp_storage_provider`.

    Any other statement is returned unchanged.

    Args:
        expression: a Create expression.
        tmp_storage_provider: callback applied to temporary-table CREATEs that have no
            query; defaults to the identity function.

    Returns:
        The rewritten statement, or the original one when no rewrite applies.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []
    is_temporary = any(isinstance(node, exp.TemporaryProperty) for node in property_nodes)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    query = expression.expression
    if not query:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(kind="TEMPORARY VIEW", this=expression.this, expression=query)
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema.

    When a CREATE TABLE/VIEW has a schema and its PARTITIONED BY value is a list of column
    names (rather than a schema), the column definitions with those names are moved out of
    the table's schema and into a new PARTITIONED BY schema. Names are matched
    case-insensitively.

    Args:
        expression: a Create expression.

    Returns:
        The same (possibly mutated) Create expression.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}

            # Split the schema by name. The previous `e not in partitions` filter relied on
            # AST node equality for every (column, partition) pair; the name test selects
            # exactly the same nodes in linear time.
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            remaining = [col for col in schema.expressions if col.name.upper() not in columns]

            schema.set("expressions", remaining)
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is a list of column names rather than a schema, the matching column definitions are moved out of the table's schema and into a new PARTITIONED BY schema, so they no longer appear in the CREATE statement's column list.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3. The rewrite applies when the
    PARTITIONED BY property is a schema made up entirely of column definitions that carry a
    type. In that case:
    - those definitions are appended to the statement's own schema;
    - PARTITIONED BY is reduced to a tuple of the column names.

    Args:
        expression: a Create expression.

    Returns:
        The same (possibly mutated) Create expression.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    column_defs = prop.this.expressions
    if not all(isinstance(e, exp.ColumnDef) and e.kind for e in column_defs):
        return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(column_def.this) for column_def in column_defs]
    )
    table_schema = expression.this
    for column_def in column_defs:
        table_schema.append("expressions", column_def)
    prop.set("this", partition_names)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Every PropertyEQ (key/value) argument of a Struct becomes `value AS key`; other
    arguments are kept as they are.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same (possibly mutated) expression.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    def _as_alias(arg: exp.Expression) -> exp.Expression:
        if isinstance(arg, exp.PropertyEQ):
            return exp.alias_(arg.expression, arg.this)
        return arg

    expression.set("expressions", [_as_alias(arg) for arg in expression.expressions])
    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function.

    The transformed expression is rendered with the first applicable option:
      1. the generator's `<key>_sql` method;
      2. its `Generator.TRANSFORMS` entry, but only if the transforms changed the expression's
         type. If the type is unchanged, that entry would re-enter this function, so `Func`
         nodes fall back to `function_fallback_sql` and anything else raises ValueError;
      3. otherwise a ValueError is raised.

    Args:
        transforms: sequence of transform functions. These will be called in order; an empty
            sequence leaves the expression unchanged.

    Returns:
        Function that can be used as a generator transform. It raises ValueError when no way
        to generate SQL for the transformed expression is found.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # A single loop applies every transform in order and also copes with an empty list,
        # where indexing transforms[0] would raise IndexError.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable, i.e. when the transforms changed the expression's type and the new type has a `TRANSFORMS` entry).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.