sqlglot.dialects.dialect
"""Base dialect machinery and shared SQL-generation/parsing helpers for sqlglot dialects."""

from __future__ import annotations

import typing as t
from enum import Enum
from functools import reduce

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)

# Unions of date arithmetic expression types shared by several helpers below
DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]


class Dialects(str, Enum):
    """Names of the SQL dialects supported by sqlglot, mapped to their string identifiers."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    Doris = "doris"


class _Dialect(type):
    """Metaclass that registers every `Dialect` subclass and wires its tokenizer,
    parser and generator classes together at class-creation time."""

    # Registry of all created dialect classes, keyed by dialect name
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class compares equal to itself, to its registered name,
        # and to instances of itself.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the enum value when the class name matches a Dialects
        # member, otherwise under the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Build tries for efficient prefix matching of time/format tokens
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        # A dialect may declare nested Tokenizer/Parser/Generator classes;
        # fall back to the base implementations otherwise.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first configured pair is taken as the canonical quote/identifier delimiters
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # Find the start/end delimiters registered for a given format-string
            # token type, or (None, None) when the dialect has none.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        # NOTE(review): `enum` here is a Dialects member (or None); the comparison
        # against plain strings works because Dialects subclasses str — confirm intent.
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        klass.generator_class.can_identify = klass.can_identify

        return klass


class Dialect(metaclass=_Dialect):
    """Base class for all SQL dialects; subclasses override the flags and nested
    Tokenizer/Parser/Generator classes below to describe dialect behavior."""

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Whether the behavior of a / b depends on the types of a and b.
    # False means a / b is always float division.
    # True means a / b is integer division if both a and b are integers.
    TYPED_DIVISION = False

    # False means 1 / 0 throws an error.
    # True means 1 / 0 returns null.
    SAFE_DIVISION = False

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        # Delegates to _Dialect.__eq__, so an instance compares equal to its class
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a dialect name, class or instance to a dialect class,
        raising ValueError for unknown names."""
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a dialect time-format string (or string literal expression)
        into a python-style format wrapped in a string literal."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an identifier as quoted when forced, case-sensitive, or not a safe name."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql`, coercing the result into the given expression type."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Render an expression tree back into SQL for this dialect."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and regenerate each statement in this dialect."""
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiated and cached per dialect instance
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


# Returns a generator callable that renders an expression as a call to `name`,
# flattening all of the expression's arguments into the call.
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, warning when an accuracy arg is present but unsupported."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a generator for IF-style conditionals, with an optional default false value."""

    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    # Emulate ILIKE by lowercasing the left-hand side of a LIKE
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    # Strip the RECURSIVE flag (with a warning) for dialects lacking recursive CTEs
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as STRPOS, emulating the optional start position via SUBSTR."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map constructor as MAP(k1, v1, k2, v2, ...) by interleaving keys and values."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions; 3-arg form is (unit, expr, this),
    2-arg form defaults the unit to DAY."""

    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date arithmetic whose second argument must be an INTERVAL."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    # DATE_TRUNC on a value cast to DATE yields a DateTrunc; everything else
    # becomes a TimestampTrunc.
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator that renders e.g. DATE_ADD(x, INTERVAL n DAY) for the
    given data type ("DATE", ...) and operation kind ("ADD", "SUB", ...)."""

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Render TIMESTAMP(x[, zone]) via CAST / AT TIME ZONE for dialects without it."""
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    # LOCATE(substr, this[, position]) -> StrPosition: note the swapped first two args
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


# NOTE(review): the parameter is annotated exp.Left but this emulates RIGHT();
# presumably it should be exp.Right — confirm upstream.
def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    # Multiple args -> LEAST, single arg -> aggregate MIN
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    # Multiple args -> GREATEST, single arg -> aggregate MAX
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0)); DISTINCT is not representable."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with the full ANSI syntax (position/chars/FROM/COLLATE) when needed."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        # A non-default format requires parsing via STR_TO_TIME before casting
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    # CONCAT(a, b, c) -> a || b || c
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    # CONCAT_WS(sep, a, b) -> a || sep || b
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Derive output column names for PIVOT aggregations (alias, or rendered SQL)."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    # Maps a 2-arg function call onto a binary expression node
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    # a XOR b == (a AND NOT b) OR (NOT a AND b)
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    # ISNULL(x) -> (x IS NULL)
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    # Both start and increment default to 1 when unspecified
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Build a generator for NAME(unit, expression, this) date-delta calls, optionally
    inserting casts for TsOrDsAdd (see ts_or_ds_add_cast)."""

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
        )

    return _delta_sql
24class Dialects(str, Enum): 25 DIALECT = "" 26 27 BIGQUERY = "bigquery" 28 CLICKHOUSE = "clickhouse" 29 DATABRICKS = "databricks" 30 DRILL = "drill" 31 DUCKDB = "duckdb" 32 HIVE = "hive" 33 MYSQL = "mysql" 34 ORACLE = "oracle" 35 POSTGRES = "postgres" 36 PRESTO = "presto" 37 REDSHIFT = "redshift" 38 SNOWFLAKE = "snowflake" 39 SPARK = "spark" 40 SPARK2 = "spark2" 41 SQLITE = "sqlite" 42 STARROCKS = "starrocks" 43 TABLEAU = "tableau" 44 TERADATA = "teradata" 45 TRINO = "trino" 46 TSQL = "tsql" 47 Doris = "doris"
An enumeration of the SQL dialects supported by sqlglot; each member maps a dialect name to the string identifier used to look the dialect up.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class Dialect(metaclass=_Dialect):
    """Base class encoding the tokenizer/parser/generator settings of a SQL dialect.

    Concrete dialects subclass this and override the class-level flags below; the
    ``_Dialect`` metaclass registers each subclass and autofills derived attributes
    (time-format tries, inverse mappings, tokenizer/parser/generator classes).
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Whether the behavior of a / b depends on the types of a and b.
    # False means a / b is always float division.
    # True means a / b is integer division if both a and b are integers.
    TYPED_DIVISION = False

    # False means 1 / 0 throws an error.
    # True means 1 / 0 returns null.
    SAFE_DIVISION = False

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        # Instances compare equal to their own class; the metaclass provides the
        # symmetric comparison, so `instance == DialectClass` works both ways
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a dialect name/class/instance to a Dialect class.

        Falsy input falls back to this class; an unknown name raises ValueError.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a dialect time-format string (or string literal) into a string
        literal holding the equivalent python strftime format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # Characters in the "wrong" case would change meaning if normalized
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an Identifier as quoted when forced, case sensitive, or not regex-safe."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse *sql* into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse *sql*, coercing each statement into *expression_type*."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Render an expression tree back into SQL text for this dialect."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse *sql* and regenerate each statement in this dialect."""
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize *sql* with this dialect's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache one tokenizer per Dialect instance
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Create a fresh parser configured with *opts*."""
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Create a fresh generator configured with *opts*."""
        return self.generator_class(**opts)
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
    """Resolve *dialect* (name, class, or instance) to a Dialect class.

    Falsy input falls back to this class; an unknown name raises ValueError.
    """
    if not dialect:
        return cls

    if isinstance(dialect, _Dialect):
        return dialect
    if isinstance(dialect, Dialect):
        return type(dialect)

    klass = cls.get(dialect)
    if klass is None:
        raise ValueError(f"Unknown dialect '{dialect}'")

    return klass
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Convert a dialect time-format string (or string literal) into a string
    literal holding the equivalent python strftime format."""
    if isinstance(expression, str):
        # Raw strings arrive quoted; strip the quotes before mapping the format
        return exp.Literal.string(
            format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
        )

    if expression and expression.is_string:
        mapped = format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)
        return exp.Literal.string(mapped)

    return expression
@classmethod
def normalize_identifier(cls, expression: E) -> E:
    """
    Normalizes an unquoted identifier to either lower or upper case, thus essentially
    making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
    they will be normalized to lowercase regardless of being quoted or not.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    # Quoted identifiers are left alone unless the dialect ignores case entirely
    if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
        return expression

    normalized = (
        expression.this.upper()
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        else expression.this.lower()
    )
    expression.set("this", normalized)
    return expression
Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized to lowercase regardless of being quoted or not.
@classmethod
def case_sensitive(cls, text: str) -> bool:
    """Checks if text contains any case sensitive characters, based on the dialect's rules."""
    uppercase = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
    if uppercase is None:
        # Fully case-insensitive dialect: no character is case sensitive
        return False

    # Characters in the "wrong" case would change meaning if normalized
    is_unsafe = str.islower if uppercase else str.isupper
    return any(is_unsafe(ch) for ch in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
@classmethod
def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            "always" or `True`: Always returns true.
            "safe": True if the identifier is case-insensitive.

    Returns:
        Whether or not the given text can be identified.
    """
    if identify == "always" or identify is True:
        return True
    if identify == "safe":
        return not cls.case_sensitive(text)
    return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: `"always"` or `True`: always returns true; `"safe"`: true if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
@classmethod
def quote_identifier(cls, expression: E, identify: bool = True) -> E:
    """Mark an Identifier as quoted when forced, case sensitive, or not regex-safe."""
    if isinstance(expression, exp.Identifier):
        name = expression.this
        needs_quotes = (
            identify
            or cls.case_sensitive(name)
            or not exp.SAFE_IDENTIFIER_RE.match(name)
        )
        expression.set("quoted", needs_quotes)

    return expression
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a generator method that renders exp.If as a `name(cond, true, false)` call.

    Args:
        name: function name to emit (defaults to IF).
        false_value: fallback for a missing false branch.
    """

    def _render(self: Generator, expression: exp.If) -> str:
        args = expression.args
        return self.func(
            name,
            expression.this,
            args.get("true"),
            args.get("false") or false_value,
        )

    return _render
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition via STRPOS, emulating a start position with SUBSTR."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    # Search within the suffix, then shift the hit back to absolute coordinates
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render MAP(...) by interleaving literal array keys and values into one arg list."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        # Column-backed maps cannot be flattened into literal key/value pairs
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        fmt = seq_get(args, 1)
        if not fmt:
            # Fall back to the dialect's time format (default=True) or the given default
            fmt = Dialect[dialect].TIME_FORMAT if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=Dialect[dialect].format_time(fmt))

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a renderer returning an expression's time format, or None when it is
    the target dialect's default (so generators can omit it)."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        fmt = self.format_time(expression)
        default_fmt = Dialect.get_or_raise(dialect).TIME_FORMAT
        return None if fmt == default_fmt else fmt

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    is_schema = isinstance(expression.this, exp.Schema)
    partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if is_schema and partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {col.name.upper() for col in prop.this.expressions}
            partition_cols = [
                col for col in schema.expressions if col.name.upper() in partition_names
            ]
            # Move the partition columns out of the schema and into the property
            schema.set("expressions", [c for c in schema.expressions if c not in partition_cols])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
            expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions, normalizing the unit argument.

    Handles both the 3-arg `(unit, delta, this)` and 2-arg `(this, delta)` shapes;
    the unit defaults to DAY and is optionally remapped through *unit_mapping*.
    """

    def _parse(args: t.List) -> E:
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)
        unit = args[0] if has_unit else exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _parse
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for DATE_ADD-style functions whose second argument is an INTERVAL."""

    def _parse(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # String amounts like '5' become numeric literals
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0], expression=amount, unit=exp.Literal.string(interval.text("unit"))
        )

    return _parse
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a renderer producing `<data_type>_<kind>(this, INTERVAL n unit)` SQL."""

    def _render(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit_arg = expression.args.get("unit")
        unit = exp.var(unit_arg.name.upper() if unit_arg else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return _render
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Rewrite TIMESTAMP(...) for dialects lacking it, using CAST / AT TIME ZONE."""
    if not expression.expression:
        # Bare TIMESTAMP(x): a plain cast suffices
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))

    if expression.text("expression").lower() in TIMEZONES:
        # TIMESTAMP(x, tz): cast first, then shift with AT TIME ZONE
        casted = exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
        return self.sql(exp.AtTimeZone(this=casted, zone=expression.expression))

    return self.function_fallback_sql(expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render ENCODE/DECODE, warning when a non-utf-8 charset is requested."""
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    # Only forward the replace argument when the dialect supports it
    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Transpile COUNT_IF(cond) into SUM(IF(cond, 1, 0))."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # SUM cannot emulate a distinct COUNT_IF; degrade with a warning
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM, preferring the generator's plain TRIM/LTRIM/RTRIM form when the
    expression carries no dialect-specific chars/collation."""
    subject = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collate = self.sql(expression, "collation")

    if not (chars or collate):
        # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
        return self.trim_sql(expression)

    parts = []
    if position:
        parts.append(f"{position} ")
    if chars:
        parts.append(f"{chars} ")
    if parts:
        parts.append("FROM ")
    parts.append(subject)
    if collate:
        parts.append(f" COLLATE {collate}")
    return f"TRIM({''.join(parts)})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a renderer that casts TS_OR_DS_TO_DATE input to DATE, parsing any
    non-default time format first."""

    def _render(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)

        if fmt and fmt not in (target.TIME_FORMAT, target.DATE_FORMAT):
            # A custom format requires an explicit string-to-time parse before the cast
            parsed = exp.StrToTime(this=expression.this, format=expression.args["format"])
            return self.sql(exp.cast(parsed, "date"))

        return self.sql(exp.cast(expression.this, "date"))

    return _render
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, flagging arguments the target dialect cannot express."""
    bad_args = [arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, flagging arguments the target dialect cannot express."""
    bad_args = [
        arg
        for arg in ("position", "occurrence", "parameters", "modifiers")
        if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Return the column-name suffixes used when generating PIVOT aliases.

    Aliased aggregations contribute their alias directly. Unaliased aggregations are
    rendered and used as suffixes themselves (e.g. col_avg(foo)); their identifiers
    are unquoted first because the base parser's `_parse_pivot` will re-quote them
    via `to_identifier`, and we'd otherwise end up with doubled quotes like
    col_avg(`foo`).
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # Fix: the explanation above previously lived here as a bare triple-quoted
            # string literal (a no-op expression statement), not a real comment.
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Build a renderer for ARG_MAX/ARG_MIN on dialects lacking the 3-arg count form."""

    def _render(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            # The count argument has no equivalent; warn and drop it
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _render
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap the TS_OR_DS_ADD operand in casts matching the expression's return type."""
    operand = expression.this.copy()
    return_type = expression.return_type

    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(operand, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Build a renderer emitting `name(unit, expression, this)` for date add/diff nodes.

    When *cast* is true, TS_OR_DS_ADD operands are first wrapped in return-type casts.
    """

    def _render(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        unit = exp.var(expression.text("unit") or "day")
        return self.func(name, unit, expression.expression, expression.this)

    return _render