# sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    """
    Metaclass of `Dialect`.

    Registers every created dialect class in `classes` and autofills the attributes that
    are derived from the class definition (tries, tokenizer / parser / generator classes,
    string and identifier delimiters, escape sequences).
    """

    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the `Dialects` enum value when there is one, else the lowercased class name
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # Classes that don't define their own Tokenizer / Parser / Generator get a fresh
        # subclass of the first base's, so the per-dialect tweaks below don't leak upwards
        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # First (start, end) delimiter pair registered for `token_type`, if any.
            # The loop variable is deliberately not called `t`, which is the typing alias.
            return next(
                (
                    (start, end)
                    for start, (end, kind) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if kind == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            # Defaults come first so that the dialect's own entries take precedence
            klass.UNESCAPED_SEQUENCES = {
                "\\a": "\a",
                "\\b": "\b",
                "\\f": "\f",
                "\\n": "\n",
                "\\r": "\r",
                "\\t": "\t",
                "\\v": "\v",
                "\\\\": "\\",
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        # `enum` is a str-based enum member (or None), so it compares equal to these names
        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    """Base dialect: bundles a tokenizer, parser and generator with dialect-level settings."""

    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

        will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    # Separator of COPY statement parameters
    COPY_PARAMS_ARE_CSV = True

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.

        Raises:
            ValueError: If the string is malformed, names an unknown dialect, or if
                `dialect` is of an unsupported type.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            # Strip once, so the lookup, the suggestion and the error message all agree
            dialect_name = dialect_name.strip()

            result = cls.get(dialect_name)
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        """
        Parses a literal `path` into a JSON path expression.

        Numeric literals are wrapped in brackets before parsing. Non-literals, and literals
        that fail to parse (a warning is logged), are returned unchanged.
        """
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses it with a parser built from `opts`."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Same as `parse`, but delegates to the parser's `parse_into` with `expression_type`."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Generates SQL for `expression` with a generator built from `opts`."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parses `sql` and generates one string per statement ("" for empty statements)."""
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenizes `sql` with this dialect's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Created on first access and then reused by this instance
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Returns a new parser bound to this dialect."""
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        """Returns a new generator bound to this dialect."""
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Returns a callback that renders an expression as `name(...)` over its flattened args."""
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Renders `APPROX_COUNT_DISTINCT(this)`, flagging an `accuracy` arg as unsupported."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """
    Returns a callback that renders `exp.If` as `name(condition, true, false)`.

    `false_value` is used when the expression has no (truthy) "false" arg.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    """
    Renders a JSON extraction with `->` (`exp.JSONExtract`) or `->>` (otherwise).

    A string-literal subject is cast to JSON in place when the generator sets
    `JSON_TYPE_REQUIRED_FOR_EXTRACTION`.
    """
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Renders an array as a bracketed list: `[a, b, ...]`."""
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    """Renders `ARRAY(<first element>)` if that element contains a query, else a bracketed list."""
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """
    Emulates `ILIKE` with `LIKE` by lowercasing both operands.

    The pattern has to be lowercased too: lowercasing only the subject would make a
    pattern that contains uppercase characters impossible to match.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=exp.Lower(this=expression.expression),
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Renders `CURRENT_DATE` without parentheses, appending `AT TIME ZONE` if a zone is set."""
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Renders a WITH clause, flagging and clearing (in place) the `recursive` arg."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulates a safe division with `IF(d <> 0, n / d, NULL)`."""
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Flags TABLESAMPLE as unsupported and renders only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Flags PIVOT as unsupported and renders nothing."""
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Renders a `TryCast` through the generator's regular `cast_sql`."""
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Flags column comment constraints as unsupported and renders nothing."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Flags MAP_FROM_ENTRIES as unsupported and renders nothing."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    """
    Renders `exp.StrPosition` as `STRPOS(this, substr[, instance])`.

    The `instance` arg is only emitted when `generate_instance` is set. A start position
    is emulated by searching in `SUBSTR(this, position)` and adding the offset back.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Renders a struct field access using dot notation: `this.field`."""
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """
    Renders a map as `map_func_name(k1, v1, k2, v2, ...)`.

    If either `keys` or `values` isn't an `exp.Array`, the conversion is flagged as
    unsupported and `map_func_name(keys, values)` is rendered instead.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Returns a callback that yields an expression's time format unless it's the default one."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """
    Returns a builder for date add / diff style functions.

    With three args they're read as `(unit, expression, this)`; otherwise as
    `(this, expression)` with a `DAY` unit. Units are translated through `unit_mapping`,
    keyed by their lowercased name, when it's provided.
    """

    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """
    Returns a builder for functions of the form `FUNC(this, INTERVAL <value> <unit>)`.

    The builder returns `None` for fewer than two args and raises `ParseError` if the
    second one isn't an `exp.Interval`.
    """

    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            # A string interval value is turned into a number literal
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Builds `DateTrunc` from `(unit, this)` if `this` is a cast to date, else `TimestampTrunc`."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Returns a callback that renders `<data_type>_<kind>(this, INTERVAL <value> <unit>)`."""

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Renders `DATE_TRUNC(<unit as a string>, this)`."""
    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """
    Renders `exp.Timestamp` for dialects without a matching single-argument function.

    Without a second arg, a cast to the annotated type (TIMESTAMP as a fallback) is
    rendered. If the second arg's text is a known timezone, a cast to TIMESTAMP wrapped
    in AT TIME ZONE is rendered. Otherwise `TIMESTAMP(this, expression)` is rendered.
    """
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    """Builds `exp.StrPosition` from args given as `(substr, this[, position])`."""
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Renders `exp.StrPosition` as `LOCATE(substr, this[, position])`."""
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Renders `exp.Left` as a substring starting at 1 with the given length."""
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Renders `exp.Right` as a substring that starts at `LENGTH(this) - (n - 1)`."""
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Renders a cast of `this` to TIMESTAMP."""
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Renders a cast of `this` to DATE."""
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """
    Renders `name(this[, replace])`, flagging any charset other than utf-8 as unsupported.

    The expression's "replace" arg is only passed along when `replace` is set.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Renders `LEAST(...)` if there's more than one operand, else `MIN(...)`."""
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Renders `GREATEST(...)` if there's more than one operand, else `MAX(...)`."""
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Emulates COUNT_IF with `sum(if(cond, 1, 0))`; DISTINCT is flagged and dropped."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """
    Renders `TRIM([position ][chars ][FROM ]target[ COLLATE collation])`.

    Falls back to the generator's own `trim_sql` when neither characters to remove nor a
    collation are given.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Renders `STRPTIME(this, <formatted time>)`."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    """Renders `CONCAT(a, b, ...)` as a left-nested chain of `exp.DPipe` nodes."""
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Renders `CONCAT_WS(delim, a, b, ...)` as `exp.DPipe` nodes with `delim` between operands."""
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Renders `REGEXP_EXTRACT(this, pattern[, group])`, flagging the other args as unsupported."""
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Renders `REGEXP_REPLACE(this, pattern, replacement)`, flagging the other args as unsupported."""
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """
    Returns one column name per pivot aggregation.

    An aliased aggregation yields its alias. Any other aggregation yields its SQL in
    `dialect`, with lowercased function names and all identifiers unquoted.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quotes).
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Returns a builder that maps the first two args to `expr_type(this=..., expression=...)`."""
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Builds `exp.TimestampTrunc` from args given as `(unit, this)`."""
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Renders `exp.AnyValue` as `MAX(this)`."""
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Emulates XOR with `(a AND (NOT b)) OR ((NOT a) AND b)`."""
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    """Checks whether `expression` is an `exp.ParseJSON` or a cast to JSON."""
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Builds the parenthesized expression `(<first arg> IS NULL)`."""
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


968def generatedasidentitycolumnconstraint_sql( 969 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 970) -> str: 971 start = self.sql(expression, "start") or "1" 972 increment = self.sql(expression, "increment") or "1" 973 return f"IDENTITY({start}, {increment})" 974 975 976def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 977 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 978 if expression.args.get("count"): 979 self.unsupported(f"Only two arguments are supported in function {name}.") 980 981 return self.func(name, expression.this, expression.expression) 982 983 return _arg_max_or_min_sql 984 985 986def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 987 this = expression.this.copy() 988 989 return_type = expression.return_type 990 if return_type.is_type(exp.DataType.Type.DATE): 991 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 992 # can truncate timestamp strings, because some dialects can't cast them to DATE 993 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 994 995 expression.this.replace(exp.cast(this, return_type)) 996 return expression 997 998 999def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1000 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1001 if cast and isinstance(expression, exp.TsOrDsAdd): 1002 expression = ts_or_ds_add_cast(expression) 1003 1004 return self.func( 1005 name, 1006 unit_to_var(expression), 1007 expression.expression, 1008 expression.this, 1009 ) 1010 1011 return _delta_sql 1012 1013 1014def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1015 unit = expression.args.get("unit") 1016 1017 if isinstance(unit, exp.Placeholder): 1018 return unit 1019 if unit: 1020 return exp.Literal.string(unit.name) 1021 return exp.Literal.string(default) if default else None 1022 
1023 1024def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1025 unit = expression.args.get("unit") 1026 1027 if isinstance(unit, (exp.Var, exp.Placeholder)): 1028 return unit 1029 return exp.Var(this=default) if default else None 1030 1031 1032def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1033 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1034 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1035 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1036 1037 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1038 1039 1040def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1041 """Remove table refs from columns in when statements.""" 1042 alias = expression.this.args.get("alias") 1043 1044 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1045 return self.dialect.normalize_identifier(identifier).name if identifier else None 1046 1047 targets = {normalize(expression.this.this)} 1048 1049 if alias: 1050 targets.add(normalize(alias.this)) 1051 1052 for when in expression.expressions: 1053 when.transform( 1054 lambda node: ( 1055 exp.column(node.this) 1056 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1057 else node 1058 ), 1059 copy=False, 1060 ) 1061 1062 return self.merge_sql(expression) 1063 1064 1065def build_json_extract_path( 1066 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1067) -> t.Callable[[t.List], F]: 1068 def _builder(args: t.List) -> F: 1069 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1070 for arg in args[1:]: 1071 if not isinstance(arg, exp.Literal): 1072 # We use the fallback parser because we can't really transpile non-literals safely 1073 return expr_type.from_arg_list(args) 1074 1075 text = arg.name 1076 if is_int(text): 1077 index = int(text) 1078 segments.append( 1079 
exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1080 ) 1081 else: 1082 segments.append(exp.JSONPathKey(this=text)) 1083 1084 # This is done to avoid failing in the expression validator due to the arg count 1085 del args[2:] 1086 return expr_type( 1087 this=seq_get(args, 0), 1088 expression=exp.JSONPath(expressions=segments), 1089 only_json_types=arrow_req_json_type, 1090 ) 1091 1092 return _builder 1093 1094 1095def json_extract_segments( 1096 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1097) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1098 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1099 path = expression.expression 1100 if not isinstance(path, exp.JSONPath): 1101 return rename_func(name)(self, expression) 1102 1103 segments = [] 1104 for segment in path.expressions: 1105 path = self.sql(segment) 1106 if path: 1107 if isinstance(segment, exp.JSONPathPart) and ( 1108 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1109 ): 1110 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1111 1112 segments.append(path) 1113 1114 if op: 1115 return f" {op} ".join([self.sql(expression.this), *segments]) 1116 return self.func(name, expression.this, *segments) 1117 1118 return _json_extract_segments 1119 1120 1121def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1122 if isinstance(expression.this, exp.JSONPathWildcard): 1123 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1124 1125 return expression.name 1126 1127 1128def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1129 cond = expression.expression 1130 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1131 alias = cond.expressions[0] 1132 cond = cond.this 1133 elif isinstance(cond, exp.Predicate): 1134 alias = "_u" 1135 else: 1136 self.unsupported("Unsupported filter condition") 1137 return "" 1138 1139 
unnest = exp.Unnest(expressions=[expression.this]) 1140 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1141 return self.sql(exp.Array(expressions=[filtered])) 1142 1143 1144def to_number_with_nls_param(self, expression: exp.ToNumber) -> str: 1145 return self.func( 1146 "TO_NUMBER", 1147 expression.this, 1148 expression.args.get("format"), 1149 expression.args.get("nlsparam"), 1150 )
class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    # Empty-string member; presumably stands for the base (unnamed) dialect — confirm
    # against the dialect registry.
    DIALECT = ""

    # Each remaining member's value is the lowercase name of a dialect.
    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
Dialects supported by SQLGlot.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    # NOTE(review): member values come from `auto()` via the `AutoName` helper; presumably
    # each value equals the member name (callers look members up by upper-cased string) —
    # confirm against `sqlglot.helper.AutoName`.

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class Dialect(metaclass=_Dialect):
    """
    Holds the settings of a SQL dialect as class attributes and exposes helpers to
    tokenize, parse, generate and transpile SQL through the dialect's `tokenizer_class`,
    `parser_class` and `generator_class`.

    The only per-instance state set up by `__init__` is `normalization_strategy`.
    """

    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    # Default format strings; note that the surrounding single quotes are part of the value
    # (`format_time` strips the first and last character of plain-string formats).
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

        will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    # Separator of COPY statement parameters
    COPY_PARAMS_ARE_CSV = True

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.

        Raises:
            ValueError: If the string is malformed, names an unknown dialect, or `dialect`
                is of an unsupported type.
        """

        # Falsy input (None, "") yields an instance of the class this was called on.
        if not dialect:
            return cls()
        # A dialect class is instantiated with default settings.
        if isinstance(dialect, _Dialect):
            return dialect()
        # An existing instance is returned as-is.
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                # "name, k1 = v1, k2 = v2" -> name plus a dict of stripped settings.
                # A pair without exactly one "=" fails to unpack and raises ValueError.
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                # Suggest the closest registered name, if any.
                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        # Anything else (None, non-string expressions) is passed through untouched.
        return expression

    def __init__(self, **kwargs) -> None:
        """
        Args:
            **kwargs: Dialect settings. Only `normalization_strategy` is read here; when
                given, it is upper-cased and converted to a `NormalizationStrategy`.
        """
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        # Note: this compares the instance's class against `other`, not two instances.
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        # The identifier is modified in place; non-identifiers are returned unchanged.
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        # Under UPPERCASE a lowercase character is "unsafe"; under every other strategy
        # an uppercase character is.
        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        # Identifiers whose parent is a function expression are left untouched.
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        """
        Parses a literal JSON path into a JSON path expression.

        Non-literal inputs, and literals that fail to parse (a warning is logged), are
        returned unchanged.
        """
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                # A numeric literal is treated as a subscript, e.g. 1 -> [1]
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses it with a new parser built from `opts`."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses it into the given expression type."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Generates SQL for `expression` with a new generator built from `opts`."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parses `sql` and generates one string per parsed statement ("" for empty ones)."""
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenizes `sql` with this dialect's (cached) tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # The tokenizer is created on first access and cached on the instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Returns a new parser bound to this dialect."""
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        """Returns a new generator bound to this dialect."""
        return self.generator_class(dialect=self, **opts)
def __init__(self, **kwargs) -> None:
    """
    Set up the per-instance dialect settings.

    Args:
        **kwargs: Dialect settings. Only `normalization_strategy` is read; when it is
            missing (or `None`) the class-level `NORMALIZATION_STRATEGY` is used,
            otherwise the value is upper-cased and converted to a `NormalizationStrategy`.
    """
    requested = kwargs.get("normalization_strategy")

    self.normalization_strategy = (
        self.NORMALIZATION_STRATEGY
        if requested is None
        else NormalizationStrategy(requested.upper())
    )
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
- `"upper"` or `True`: Convert names to uppercase.
- `"lower"`: Convert names to lowercase.
- `False`: Disables function name normalization.
Whether the base comes first in the `LOG` function.
Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`).
Default `NULL` ordering method to use if not explicitly set.
Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.
Whether the behavior of `a / b` depends on the types of `a` and `b`.
False means `a / b` is always float division.
True means `a / b` is integer division if both `a` and `b` are integers.
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime
formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy')
.
If empty, the corresponding trie will be constructed off of TIME_MAPPING
.
Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT *
queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
    """
    Look up a dialect in the global dialect registry and return it if it exists.

    Args:
        dialect: The target dialect. If this is a string, it can be optionally followed by
            additional key-value pairs that are separated by commas and are used to specify
            dialect settings, such as whether the dialect's identifiers are case-sensitive.

    Example:
        >>> dialect = dialect_class = Dialect.get_or_raise("duckdb")
        >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

    Returns:
        The corresponding Dialect instance.

    Raises:
        ValueError: If the string is malformed, names an unknown dialect, or `dialect`
            is of an unsupported type.
    """

    # Falsy input (None, "") yields an instance of the class this was called on.
    if not dialect:
        return cls()
    # A dialect class is instantiated with default settings.
    if isinstance(dialect, _Dialect):
        return dialect()
    # An existing instance is returned as-is.
    if isinstance(dialect, Dialect):
        return dialect
    if isinstance(dialect, str):
        try:
            # "name, k1 = v1, k2 = v2" -> name plus a dict of stripped settings.
            # A pair without exactly one "=" fails to unpack and raises ValueError.
            dialect_name, *kv_pairs = dialect.split(",")
            kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
        except ValueError as err:
            raise ValueError(
                f"Invalid dialect format: '{dialect}'. "
                "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
            ) from err

        # Normalize the name once so that the lookup, the close-match suggestion and the
        # error message all agree (previously only the lookup used the stripped name).
        dialect_name = dialect_name.strip()

        result = cls.get(dialect_name)
        if not result:
            from difflib import get_close_matches

            # Suggest the closest registered name, if any.
            similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
            if similar:
                similar = f" Did you mean {similar}?"

            raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

        return result(**kwargs)

    raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """
    Convert a time format written for this dialect into its Python `strftime` equivalent.

    Plain strings are expected to be quoted: their first and last characters are dropped
    before conversion. String-literal expressions are converted from their `this` value.
    Any other input (including `None`) is returned unchanged.
    """
    if isinstance(expression, str):
        # the time formats are quoted
        raw_format = expression[1:-1]
    elif expression and expression.is_string:
        raw_format = expression.this
    else:
        return expression

    return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))
Converts a time format in this dialect to its equivalent Python strftime
format.
def normalize_identifier(self, expression: E) -> E:
    """
    Normalize an identifier in place, mimicking how this dialect would resolve it.

    The outcome depends on `self.normalization_strategy`:
      - CASE_SENSITIVE: nothing is ever changed.
      - CASE_INSENSITIVE: the name is lowercased, whether or not it is quoted.
      - UPPERCASE / LOWERCASE: unquoted names are upper/lowercased; quoted names are
        left alone so that they are not "broken".

    For instance `FoO` resolves to `foo` in a lowercasing dialect such as Postgres and to
    `FOO` in an uppercasing one such as Snowflake. Expressions that are not identifiers
    are returned untouched.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    strategy = self.normalization_strategy
    if strategy is NormalizationStrategy.CASE_SENSITIVE:
        return expression

    # Quoted identifiers are only folded by case-insensitive dialects.
    if expression.quoted and strategy is not NormalizationStrategy.CASE_INSENSITIVE:
        return expression

    if strategy is NormalizationStrategy.UPPERCASE:
        folded = expression.this.upper()
    else:
        folded = expression.this.lower()

    expression.set("this", folded)
    return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
def case_sensitive(self, text: str) -> bool:
    """
    Tell whether `text` contains a character whose case matters under this dialect's rules.

    Case-insensitive dialects never report a match. Uppercasing dialects flag lowercase
    characters; every other strategy flags uppercase characters.
    """
    strategy = self.normalization_strategy
    if strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False

    if strategy is NormalizationStrategy.UPPERCASE:
        is_unsafe = str.islower
    else:
        is_unsafe = str.isupper

    for char in text:
        if is_unsafe(char):
            return True
    return False
Checks if text contains any case sensitive characters, based on the dialect's rules.
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
    """Decide whether `text` can be identified under the given `identify` option.

    Args:
        text: The text to check.
        identify:
            `"always"` or `True`: the answer is always `True`.
            `"safe"`: the answer is `True` only when `text` has no case-sensitive characters.
            Any other value: the answer is `False`.

    Returns:
        Whether the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    if identify != "safe":
        return False

    return not self.case_sensitive(text)
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: `"always"` or `True`: Always returns `True`. `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Set the `quoted` flag on an identifier, in place.

    Args:
        expression: The expression of interest. Anything that is not an `Identifier`, or an
            identifier whose parent is a function expression, is returned untouched.
        identify: When `False`, quotes are only added if the identifier is "unsafe": it has
            case-sensitive characters for this dialect or does not match
            `exp.SAFE_IDENTIFIER_RE`.

    Returns:
        The same expression object.
    """
    if not isinstance(expression, exp.Identifier) or isinstance(expression.parent, exp.Func):
        return expression

    name = expression.this
    needs_quotes = (
        identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
    )
    expression.set("quoted", needs_quotes)
    return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    """
    Turn a literal JSON path into a parsed JSON path expression.

    Numeric literals are wrapped in brackets first, so they are parsed as subscripts.
    Non-literal inputs, and literals that fail to parse (a warning is logged), are
    returned unchanged.
    """
    if not isinstance(path, exp.Literal):
        return path

    raw_text = path.name
    path_text = f"[{raw_text}]" if path.is_number else raw_text

    try:
        parsed = parse_json_path(path_text)
    except ParseError as e:
        logger.warning(f"Invalid JSON path syntax. {str(e)}")
        return path

    return parsed
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """
    Build a generator function that renders an `If` expression as a function call.

    Args:
        name: The function name to emit.
        false_value: Fallback used as the third argument when the expression has no
            (truthy) "false" branch.

    Returns:
        A callable taking the generator and the expression, and returning the result of
        `self.func(name, condition, true_branch, false_branch)`.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        branches = expression.args
        when_true = branches.get("true")
        when_false = branches.get("false") or false_value
        return self.func(name, expression.this, when_true, when_false)

    return _if_sql
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    """Render a JSON extraction with the `->` / `->>` binary operators.

    `->` is used for `exp.JSONExtract` nodes and `->>` for anything else. When the generator
    sets `JSON_TYPE_REQUIRED_FOR_EXTRACTION` and the operand is a string literal, the operand
    is replaced in the tree by a cast to JSON before rendering.
    """
    operand = expression.this
    needs_json_cast = (
        self.JSON_TYPE_REQUIRED_FOR_EXTRACTION
        and isinstance(operand, exp.Literal)
        and operand.is_string
    )
    if needs_json_cast:
        operand.replace(exp.cast(operand, exp.DataType.Type.JSON))

    if isinstance(expression, exp.JSONExtract):
        operator = "->"
    else:
        operator = "->>"

    return self.binary(expression, operator)
def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    """Render `exp.StrPosition` as a `STRPOS` call.

    Args:
        expression: The node to render; its "this", "substr", "position" and "instance"
            args are read.
        generate_instance: Whether the "instance" arg is forwarded to `STRPOS`.

    Returns:
        `STRPOS(this, substr[, instance])`. When a position is present, the search runs on
        `SUBSTR(this, position)` and ` + position - 1` is appended to the result.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")
    occurrence = expression.args.get("instance") if generate_instance else None

    if not start:
        return self.func("STRPOS", haystack, needle, occurrence)

    # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
    sliced = self.func("SUBSTR", haystack, start)
    return self.func("STRPOS", sliced, needle, occurrence) + f" + {start} - 1"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map node as `map_func_name(k1, v1, k2, v2, ...)`.

    If either the "keys" or the "values" arg is not an `exp.Array`, the conversion is reported
    as unsupported and both args are passed to the function unchanged.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
        # Interleave the rendered keys and values pairwise
        interleaved = [
            self.sql(item)
            for pair in zip(keys.expressions, values.expressions)
            for item in pair
        ]
        return self.func(map_func_name, *interleaved)

    self.unsupported("Cannot convert array columns into map.")
    return self.func(map_func_name, keys, values)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        # The dialect is looked up at call time, not when the builder is created
        dialect_cls = Dialect[dialect]

        fmt = seq_get(args, 1)
        if not fmt:
            # No explicit format: use the dialect's TIME_FORMAT or the supplied default
            fmt = dialect_cls.TIME_FORMAT if default is True else default or None

        return exp_class(this=seq_get(args, 0), format=dialect_cls.format_time(fmt))

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a callback that yields an expression's time format unless it is the default one.

    Args:
        dialect: The dialect whose `TIME_FORMAT` is compared against; it is resolved with
            `Dialect.get_or_raise` on every call of the returned callback.
    """

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        rendered = self.format_time(expression)
        default_format = Dialect.get_or_raise(dialect).TIME_FORMAT
        if rendered != default_format:
            return rendered
        return None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for date-delta functions.

    With exactly three arguments the layout is `(unit, expression, this)`; otherwise it is
    `(this, expression)` and the unit defaults to the string literal "DAY".

    Args:
        exp_class: The expression class to instantiate.
        unit_mapping: Optional mapping keyed by lowercased unit name; when provided, the unit
            becomes an `exp.var` of the mapped name, or of the original name if unmapped.
    """

    def _builder(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callback for date functions whose second argument is an INTERVAL.

    The returned builder yields `None` when given fewer than two arguments.

    Raises:
        ParseError: (from the builder) if the second argument is not an `exp.Interval`.
    """

    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        this, interval = args[0], args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        # A string interval value is rebuilt as a number literal
        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        return expression_class(this=this, expression=amount, unit=unit_to_str(interval))

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator callback that renders `<data_type>_<kind>(this, <interval>)`.

    Args:
        data_type: Prefix of the emitted function name.
        kind: Suffix of the emitted function name.
    """

    def func(self: Generator, expression: exp.Expression) -> str:
        operand_sql = self.sql(expression, "this")

        # The delta is wrapped into an Interval node built from the expression and its unit
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        interval_sql = self.sql(interval)

        return f"{data_type}_{kind}({operand_sql}, {interval_sql})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Render `exp.Timestamp` without a one-argument `TIMESTAMP` function.

    - No second argument: a cast to the annotated type, or to TIMESTAMP if there is none.
    - Second argument whose lowercased text is in `TIMEZONES`: a TIMESTAMP cast wrapped in
      an `exp.AtTimeZone` node.
    - Otherwise: a plain two-argument `TIMESTAMP(...)` call.
    """
    zone = expression.expression

    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))

    if expression.text("expression").lower() not in TIMEZONES:
        return self.func("TIMESTAMP", expression.this, zone)

    as_timestamp = exp.cast(expression.this, exp.DataType.Type.TIMESTAMP)
    return self.sql(exp.AtTimeZone(this=as_timestamp, zone=zone))
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an encode/decode style node as `name(this[, replace])`.

    Args:
        expression: The node to render; its "charset" and "replace" args are read.
        name: The function name to emit.
        replace: Whether the node's "replace" arg is forwarded as the second argument.

    A charset whose lowercased name is not "utf-8" is reported through `self.unsupported`.
    """
    node_args = expression.args

    charset = node_args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replacement = node_args.get("replace") if replace else None
    return self.func(name, expression.this, replacement)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render `exp.CountIf` as `sum(if(<cond>, 1, 0))`.

    When the operand is an `exp.Distinct`, its first expression is used as the condition and
    the loss of DISTINCT is reported through `self.unsupported`.
    """
    operand = expression.this

    if isinstance(operand, exp.Distinct):
        operand = operand.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", operand, 1, 0)
    return self.func("sum", one_or_zero)
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render `exp.Trim` using the `TRIM([position] [chars] FROM target [COLLATE c])` syntax.

    When neither trim characters nor a collation are present, rendering is delegated to the
    generator's own `trim_sql` method.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    pieces = []
    if position:
        pieces.append(f"{position} ")
    if chars:
        pieces.append(f"{chars} ")
    if pieces:
        # FROM is only emitted when something precedes the target
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    return f"TRIM({''.join(pieces)})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render `exp.RegexpExtract` as `REGEXP_EXTRACT(this, expression[, group])`.

    Any truthy "position", "occurrence" or "parameters" arg is reported through
    `self.unsupported` and left out of the generated call.
    """
    node_args = expression.args
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if node_args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = node_args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render `exp.RegexpReplace` as `REGEXP_REPLACE(this, expression, replacement)`.

    Any truthy "position", "occurrence" or "modifiers" arg is reported through
    `self.unsupported` and left out of the generated call. The "replacement" arg is read by
    direct key access, so a node without that key raises `KeyError`.
    """
    node_args = expression.args
    bad_args = [
        arg for arg in ("position", "occurrence", "modifiers") if node_args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = node_args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the column-name suffixes contributed by a pivot's aggregations.

    Args:
        aggregations: The aggregation expressions of the pivot.
        dialect: The dialect used to render aggregations that have no alias.

    Returns:
        One name per aggregation: the alias for `exp.Alias` nodes, otherwise the SQL of the
        aggregation with every identifier unquoted and function names lowercased.
    """

    def _unquote(node):
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with `col_avg(`foo`)` (notice the doubled quoting).
        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Build a generator callback that renders ArgMax/ArgMin as a two-argument `name(...)` call.

    A truthy "count" arg on the node is reported through `self.unsupported` and is not
    included in the generated call.
    """

    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        has_count = expression.args.get("count")
        if has_count:
            self.unsupported(f"Only two arguments are supported in function {name}.")

        value, ordering = expression.this, expression.expression
        return self.func(name, value, ordering)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap the operand of a `TsOrDsAdd` node in a cast to the node's return type.

    The node is modified in place (its operand is replaced) and then returned.
    """
    operand = expression.this.copy()
    return_type = expression.return_type

    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    casted = exp.cast(operand, return_type)
    expression.this.replace(casted)
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Build a generator callback that renders `name(unit, expression, this)`.

    Args:
        name: The function name to emit.
        cast: If set, `exp.TsOrDsAdd` nodes are first passed through `ts_or_ds_add_cast`.
    """

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        node = expression
        if cast and isinstance(node, exp.TsOrDsAdd):
            node = ts_or_ds_add_cast(node)

        unit = unit_to_var(node)
        return self.func(name, unit, node.expression, node.this)

    return _delta_sql
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the "unit" arg of `expression` as a string literal.

    Args:
        expression: The node whose "unit" arg is read.
        default: The unit name used when the node has no unit; a falsy value disables it.

    Returns:
        The unit itself if it is an `exp.Placeholder`, a string literal of the unit's name if
        a unit is set, a string literal of `default` if that is truthy, and `None` otherwise.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit

    if unit:
        return exp.Literal.string(unit.name)

    if default:
        return exp.Literal.string(default)

    return None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Render `exp.LastDay` without a native LAST_DAY function.

    The result is built as: truncate to month, add one month, subtract one day, cast to DATE.
    """
    month_start = exp.func("date_trunc", "month", expression.this)
    next_month_start = exp.func("date_add", month_start, 1, "month")
    month_end = exp.func("date_sub", next_month_start, 1, "day")

    as_date = exp.cast(month_end, exp.DataType.Type.DATE)
    return self.sql(as_date)
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements.

    Columns whose normalized table name matches the merge target, or the target's alias, are
    replaced in place by unqualified columns. The node is then rendered with `self.merge_sql`.
    """

    def _normalized_name(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        if not identifier:
            return None
        return self.dialect.normalize_identifier(identifier).name

    target = expression.this
    alias = target.args.get("alias")

    target_names = {_normalized_name(target.this)}
    if alias:
        target_names.add(_normalized_name(alias.this))

    def _strip_target(node):
        if (
            isinstance(node, exp.Column)
            and _normalized_name(node.args.get("table")) in target_names
        ):
            return exp.column(node.this)
        return node

    for when in expression.expressions:
        when.transform(_strip_target, copy=False)

    return self.merge_sql(expression)
Remove table refs from columns in when statements.
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    """Build a parser callback that folds variadic path arguments into one `exp.JSONPath`.

    Args:
        expr_type: The expression class to instantiate.
        zero_based_indexing: If `False`, integer path arguments are decremented by one.
        arrow_req_json_type: Forwarded as the node's `only_json_types` arg.
    """
    # Amount subtracted from integer segments to obtain the stored subscript
    index_shift = 0 if zero_based_indexing else 1

    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]

        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                segment = exp.JSONPathSubscript(this=int(text) - index_shift)
            else:
                segment = exp.JSONPathKey(this=text)
            segments.append(segment)

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]

        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    """Build a generator callback that renders a JSON extraction one path segment at a time.

    Args:
        name: The function name to emit when `op` is not given.
        quoted_index: Whether subscript segments are wrapped in the dialect's quotes as well.
        op: If given, the operand and the segments are joined with ` op ` instead of being
            passed to a function call.
    """

    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        json_path = expression.expression
        if not isinstance(json_path, exp.JSONPath):
            return rename_func(name)(self, expression)

        rendered = []
        for part in json_path.expressions:
            part_sql = self.sql(part)
            if not part_sql:
                # Segments that render to nothing are skipped
                continue

            should_quote = isinstance(part, exp.JSONPathPart) and (
                quoted_index or not isinstance(part, exp.JSONPathSubscript)
            )
            if should_quote:
                part_sql = f"{self.dialect.QUOTE_START}{part_sql}{self.dialect.QUOTE_END}"

            rendered.append(part_sql)

        if op:
            return f" {op} ".join([self.sql(expression.this), *rendered])
        return self.func(name, expression.this, *rendered)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    """Render `exp.ArrayFilter` as an array built from a filtered UNNEST subquery.

    A single-parameter lambda contributes its parameter as the column alias and its body as
    the filter; a bare `exp.Predicate` is used as the filter with the alias "_u". Any other
    condition is reported through `self.unsupported` and an empty string is returned.
    """
    predicate = expression.expression

    if isinstance(predicate, exp.Lambda) and len(predicate.expressions) == 1:
        alias, predicate = predicate.expressions[0], predicate.this
    elif isinstance(predicate, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnested = exp.Unnest(expressions=[expression.this])
    projection = exp.select(alias)
    subquery = projection.from_(exp.alias_(unnested, None, table=[alias])).where(predicate)

    return self.sql(exp.Array(expressions=[subquery]))