sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator, unsupported_args
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses, to_bool
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time, subsecond_precision
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[
    exp.DateAdd,
    exp.DateDiff,
    exp.DateSub,
    exp.TsOrDsAdd,
    exp.TsOrDsDiff,
]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    from sqlglot.optimizer.annotate_types import TypeAnnotator

    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        klass.INVERSE_CREATABLE_KIND_MAPPING = {
            v: k for k, v in klass.CREATABLE_KIND_MAPPING.items()
        }

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass

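# A minimal sketch of what the registry above provides: defining a `Dialect`
# subclass is enough to make it discoverable by name, because `_Dialect.__new__`
# stores every new class in `classes`, keyed by the matching `Dialects` value or
# the lowercased class name. The `Custom` dialect below is hypothetical.
#
#     class Custom(Dialect):
#         NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
#
#     assert Dialect.get("custom") is Custom
#     assert Dialect["custom"] == "custom"  # _Dialect.__eq__ also accepts strings
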
class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Separator of COPY statement parameters."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    PRESERVE_ORIGINAL_NAMES: bool = False
    """
    Whether the name of the function should be preserved inside the node's metadata.
    This can be useful for roundtripping deprecated vs. new functions that share an AST node,
    e.g. JSON_VALUE vs. JSON_EXTRACT_SCALAR in BigQuery.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """
297 """ 298 299 SAFE_DIVISION = False 300 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 301 302 CONCAT_COALESCE = False 303 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 304 305 HEX_LOWERCASE = False 306 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 307 308 DATE_FORMAT = "'%Y-%m-%d'" 309 DATEINT_FORMAT = "'%Y%m%d'" 310 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 311 312 TIME_MAPPING: t.Dict[str, str] = {} 313 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 314 315 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 316 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 317 FORMAT_MAPPING: t.Dict[str, str] = {} 318 """ 319 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 320 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 321 """ 322 323 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 324 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 325 326 PSEUDOCOLUMNS: t.Set[str] = set() 327 """ 328 Columns that are auto-generated by the engine corresponding to this dialect. 329 For example, such columns may be excluded from `SELECT *` queries. 330 """ 331 332 PREFER_CTE_ALIAS_COLUMN = False 333 """ 334 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 335 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 336 any projection aliases in the subquery. 337 338 For example, 339 WITH y(c) AS ( 340 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 341 ) SELECT c FROM y; 342 343 will be rewritten as 344 345 WITH y(c) AS ( 346 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 347 ) SELECT c FROM y; 348 """ 349 350 COPY_PARAMS_ARE_CSV = True 351 """ 352 Whether COPY statement parameters are separated by comma or whitespace 353 """ 354 355 FORCE_EARLY_ALIAS_REF_EXPANSION = False 356 """ 357 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 

    For example:
        WITH data AS (
            SELECT
                1 AS id,
                2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id,
        HAVING
            my_id = 1

    In most dialects, "my_id" would refer to "data.my_id" across the query, except:
        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e
          it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - Clickhouse, which will forward the alias across the query i.e it resolves
          to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    SUPPORTS_ORDER_BY_ALL = False
    """
    Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
    """

    HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
    """
    Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3)
    as the former is of type INT[] vs the latter which is SUPER
    """

    SUPPORTS_FIXED_SIZE_ARRAYS = False
    """
    Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g.
    in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should
    be interpreted as a subscript/index operator.
    """

    STRICT_JSON_PATH_SYNTAX = True
    """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning."""

    ON_CONDITION_EMPTY_BEFORE_ERROR = True
    """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle)."""

    ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True
    """Whether ArrayAgg needs to filter NULL values."""

    PROMOTE_TO_INFERRED_DATETIME_TYPE = False
    """
    This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted
    to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal
    is cast to x's type to match it instead.
    """

    SUPPORTS_VALUES_DEFAULT = True
    """Whether the DEFAULT keyword is supported in the VALUES clause."""

    REGEXP_EXTRACT_DEFAULT_GROUP = 0
    """The default value for the capturing group."""

    SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = {
        exp.Except: True,
        exp.Intersect: True,
        exp.Union: True,
    }
    """
    Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL`
    must be explicitly specified.
    """

    CREATABLE_KIND_MAPPING: dict[str, str] = {}
    """
    Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse
    equivalent of CREATE SCHEMA is CREATE DATABASE.
440 """ 441 442 # --- Autofilled --- 443 444 tokenizer_class = Tokenizer 445 jsonpath_tokenizer_class = JSONPathTokenizer 446 parser_class = Parser 447 generator_class = Generator 448 449 # A trie of the time_mapping keys 450 TIME_TRIE: t.Dict = {} 451 FORMAT_TRIE: t.Dict = {} 452 453 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 454 INVERSE_TIME_TRIE: t.Dict = {} 455 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 456 INVERSE_FORMAT_TRIE: t.Dict = {} 457 458 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 459 460 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 461 462 # Delimiters for string literals and identifiers 463 QUOTE_START = "'" 464 QUOTE_END = "'" 465 IDENTIFIER_START = '"' 466 IDENTIFIER_END = '"' 467 468 # Delimiters for bit, hex, byte and unicode literals 469 BIT_START: t.Optional[str] = None 470 BIT_END: t.Optional[str] = None 471 HEX_START: t.Optional[str] = None 472 HEX_END: t.Optional[str] = None 473 BYTE_START: t.Optional[str] = None 474 BYTE_END: t.Optional[str] = None 475 UNICODE_START: t.Optional[str] = None 476 UNICODE_END: t.Optional[str] = None 477 478 DATE_PART_MAPPING = { 479 "Y": "YEAR", 480 "YY": "YEAR", 481 "YYY": "YEAR", 482 "YYYY": "YEAR", 483 "YR": "YEAR", 484 "YEARS": "YEAR", 485 "YRS": "YEAR", 486 "MM": "MONTH", 487 "MON": "MONTH", 488 "MONS": "MONTH", 489 "MONTHS": "MONTH", 490 "D": "DAY", 491 "DD": "DAY", 492 "DAYS": "DAY", 493 "DAYOFMONTH": "DAY", 494 "DAY OF WEEK": "DAYOFWEEK", 495 "WEEKDAY": "DAYOFWEEK", 496 "DOW": "DAYOFWEEK", 497 "DW": "DAYOFWEEK", 498 "WEEKDAY_ISO": "DAYOFWEEKISO", 499 "DOW_ISO": "DAYOFWEEKISO", 500 "DW_ISO": "DAYOFWEEKISO", 501 "DAY OF YEAR": "DAYOFYEAR", 502 "DOY": "DAYOFYEAR", 503 "DY": "DAYOFYEAR", 504 "W": "WEEK", 505 "WK": "WEEK", 506 "WEEKOFYEAR": "WEEK", 507 "WOY": "WEEK", 508 "WY": "WEEK", 509 "WEEK_ISO": "WEEKISO", 510 "WEEKOFYEARISO": "WEEKISO", 511 "WEEKOFYEAR_ISO": "WEEKISO", 512 "Q": "QUARTER", 513 "QTR": "QUARTER", 514 "QTRS": "QUARTER", 515 "QUARTERS": "QUARTER", 516 "H": "HOUR", 517 "HH": "HOUR", 518 "HR": "HOUR", 519 "HOURS": "HOUR", 520 "HRS": "HOUR", 521 "M": "MINUTE", 522 "MI": "MINUTE", 523 "MIN": "MINUTE", 524 "MINUTES": "MINUTE", 525 "MINS": "MINUTE", 526 "S": "SECOND", 527 "SEC": "SECOND", 528 "SECONDS": "SECOND", 529 "SECS": "SECOND", 530 "MS": "MILLISECOND", 531 "MSEC": "MILLISECOND", 532 "MSECS": "MILLISECOND", 533 "MSECOND": "MILLISECOND", 534 "MSECONDS": "MILLISECOND", 535 "MILLISEC": "MILLISECOND", 536 "MILLISECS": "MILLISECOND", 537 "MILLISECON": "MILLISECOND", 538 "MILLISECONDS": "MILLISECOND", 539 "US": "MICROSECOND", 540 "USEC": "MICROSECOND", 541 "USECS": "MICROSECOND", 542 "MICROSEC": "MICROSECOND", 543 "MICROSECS": "MICROSECOND", 544 "USECOND": "MICROSECOND", 545 "USECONDS": "MICROSECOND", 546 "MICROSECONDS": "MICROSECOND", 547 "NS": "NANOSECOND", 548 "NSEC": "NANOSECOND", 549 "NANOSEC": "NANOSECOND", 550 "NSECOND": "NANOSECOND", 551 "NSECONDS": "NANOSECOND", 552 "NANOSECS": "NANOSECOND", 553 "EPOCH_SECOND": "EPOCH", 554 "EPOCH_SECONDS": "EPOCH", 555 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 556 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 557 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 558 "TZH": "TIMEZONE_HOUR", 559 "TZM": "TIMEZONE_MINUTE", 560 "DEC": "DECADE", 561 "DECS": "DECADE", 562 "DECADES": "DECADE", 563 "MIL": "MILLENIUM", 564 "MILS": "MILLENIUM", 565 "MILLENIA": "MILLENIUM", 566 "C": "CENTURY", 567 "CENT": "CENTURY", 568 "CENTS": "CENTURY", 569 "CENTURIES": "CENTURY", 570 } 571 572 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 573 exp.DataType.Type.BIGINT: 
{ 574 exp.ApproxDistinct, 575 exp.ArraySize, 576 exp.Length, 577 }, 578 exp.DataType.Type.BOOLEAN: { 579 exp.Between, 580 exp.Boolean, 581 exp.In, 582 exp.RegexpLike, 583 }, 584 exp.DataType.Type.DATE: { 585 exp.CurrentDate, 586 exp.Date, 587 exp.DateFromParts, 588 exp.DateStrToDate, 589 exp.DiToDate, 590 exp.StrToDate, 591 exp.TimeStrToDate, 592 exp.TsOrDsToDate, 593 }, 594 exp.DataType.Type.DATETIME: { 595 exp.CurrentDatetime, 596 exp.Datetime, 597 exp.DatetimeAdd, 598 exp.DatetimeSub, 599 }, 600 exp.DataType.Type.DOUBLE: { 601 exp.ApproxQuantile, 602 exp.Avg, 603 exp.Exp, 604 exp.Ln, 605 exp.Log, 606 exp.Pow, 607 exp.Quantile, 608 exp.Round, 609 exp.SafeDivide, 610 exp.Sqrt, 611 exp.Stddev, 612 exp.StddevPop, 613 exp.StddevSamp, 614 exp.ToDouble, 615 exp.Variance, 616 exp.VariancePop, 617 }, 618 exp.DataType.Type.INT: { 619 exp.Ceil, 620 exp.DatetimeDiff, 621 exp.DateDiff, 622 exp.TimestampDiff, 623 exp.TimeDiff, 624 exp.DateToDi, 625 exp.Levenshtein, 626 exp.Sign, 627 exp.StrPosition, 628 exp.TsOrDiToDi, 629 }, 630 exp.DataType.Type.JSON: { 631 exp.ParseJSON, 632 }, 633 exp.DataType.Type.TIME: { 634 exp.Time, 635 }, 636 exp.DataType.Type.TIMESTAMP: { 637 exp.CurrentTime, 638 exp.CurrentTimestamp, 639 exp.StrToTime, 640 exp.TimeAdd, 641 exp.TimeStrToTime, 642 exp.TimeSub, 643 exp.TimestampAdd, 644 exp.TimestampSub, 645 exp.UnixToTime, 646 }, 647 exp.DataType.Type.TINYINT: { 648 exp.Day, 649 exp.Month, 650 exp.Week, 651 exp.Year, 652 exp.Quarter, 653 }, 654 exp.DataType.Type.VARCHAR: { 655 exp.ArrayConcat, 656 exp.Concat, 657 exp.ConcatWs, 658 exp.DateToDateStr, 659 exp.GroupConcat, 660 exp.Initcap, 661 exp.Lower, 662 exp.Substring, 663 exp.String, 664 exp.TimeToStr, 665 exp.TimeToTimeStr, 666 exp.Trim, 667 exp.TsOrDsToDateStr, 668 exp.UnixToStr, 669 exp.UnixToTimeStr, 670 exp.Upper, 671 }, 672 } 673 674 ANNOTATORS: AnnotatorsType = { 675 **{ 676 expr_type: lambda self, e: self._annotate_unary(e) 677 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 678 }, 679 **{ 680 expr_type: lambda self, e: self._annotate_binary(e) 681 for expr_type in subclasses(exp.__name__, exp.Binary) 682 }, 683 **{ 684 expr_type: _annotate_with_type_lambda(data_type) 685 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 686 for expr_type in expressions 687 }, 688 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 689 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 690 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 691 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 692 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 693 exp.Bracket: lambda self, e: self._annotate_bracket(e), 694 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 695 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 696 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 697 exp.Count: lambda self, e: self._annotate_with_type( 698 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 699 ), 700 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 701 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 702 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 703 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 704 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 705 exp.Div: lambda self, e: self._annotate_div(e), 706 exp.Dot: lambda 
self, e: self._annotate_dot(e), 707 exp.Explode: lambda self, e: self._annotate_explode(e), 708 exp.Extract: lambda self, e: self._annotate_extract(e), 709 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 710 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 711 e, exp.DataType.build("ARRAY<DATE>") 712 ), 713 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 714 e, exp.DataType.build("ARRAY<TIMESTAMP>") 715 ), 716 exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 717 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 718 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 719 exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 720 exp.Literal: lambda self, e: self._annotate_literal(e), 721 exp.Map: lambda self, e: self._annotate_map(e), 722 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 723 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 724 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 725 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 726 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 727 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 728 exp.Struct: lambda self, e: self._annotate_struct(e), 729 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 730 exp.Timestamp: lambda self, e: self._annotate_with_type( 731 e, 732 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 733 ), 734 exp.ToMap: lambda self, e: self._annotate_to_map(e), 735 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 736 exp.Unnest: lambda self, e: self._annotate_unnest(e), 737 exp.VarMap: lambda self, e: self._annotate_map(e), 738 } 739 740 @classmethod 741 def get_or_raise(cls, dialect: DialectType) -> Dialect: 742 """ 743 Look up a dialect in the global dialect registry and return it if it exists. 744 745 Args: 746 dialect: The target dialect. If this is a string, it can be optionally followed by 747 additional key-value pairs that are separated by commas and are used to specify 748 dialect settings, such as whether the dialect's identifiers are case-sensitive. 749 750 Example: 751 >>> dialect = dialect_class = get_or_raise("duckdb") 752 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 753 754 Returns: 755 The corresponding Dialect instance. 756 """ 757 758 if not dialect: 759 return cls() 760 if isinstance(dialect, _Dialect): 761 return dialect() 762 if isinstance(dialect, Dialect): 763 return dialect 764 if isinstance(dialect, str): 765 try: 766 dialect_name, *kv_strings = dialect.split(",") 767 kv_pairs = (kv.split("=") for kv in kv_strings) 768 kwargs = {} 769 for pair in kv_pairs: 770 key = pair[0].strip() 771 value: t.Union[bool | str | None] = None 772 773 if len(pair) == 1: 774 # Default initialize standalone settings to True 775 value = True 776 elif len(pair) == 2: 777 value = pair[1].strip() 778 779 kwargs[key] = to_bool(value) 780 781 except ValueError: 782 raise ValueError( 783 f"Invalid dialect format: '{dialect}'. " 784 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 
785 ) 786 787 result = cls.get(dialect_name.strip()) 788 if not result: 789 from difflib import get_close_matches 790 791 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 792 if similar: 793 similar = f" Did you mean {similar}?" 794 795 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 796 797 return result(**kwargs) 798 799 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 800 801 @classmethod 802 def format_time( 803 cls, expression: t.Optional[str | exp.Expression] 804 ) -> t.Optional[exp.Expression]: 805 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 806 if isinstance(expression, str): 807 return exp.Literal.string( 808 # the time formats are quoted 809 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 810 ) 811 812 if expression and expression.is_string: 813 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 814 815 return expression 816 817 def __init__(self, **kwargs) -> None: 818 normalization_strategy = kwargs.pop("normalization_strategy", None) 819 820 if normalization_strategy is None: 821 self.normalization_strategy = self.NORMALIZATION_STRATEGY 822 else: 823 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 824 825 self.settings = kwargs 826 827 def __eq__(self, other: t.Any) -> bool: 828 # Does not currently take dialect state into account 829 return type(self) == other 830 831 def __hash__(self) -> int: 832 # Does not currently take dialect state into account 833 return hash(type(self)) 834 835 def normalize_identifier(self, expression: E) -> E: 836 """ 837 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 838 839 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 840 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 841 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 842 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 843 844 There are also dialects like Spark, which are case-insensitive even when quotes are 845 present, and dialects like MySQL, whose resolution rules match those employed by the 846 underlying operating system, for example they may always be case-sensitive in Linux. 847 848 Finally, the normalization behavior of some engines can even be controlled through flags, 849 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 850 851 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 852 that it can analyze queries in the optimizer and successfully capture their semantics. 
853 """ 854 if ( 855 isinstance(expression, exp.Identifier) 856 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 857 and ( 858 not expression.quoted 859 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 860 ) 861 ): 862 expression.set( 863 "this", 864 ( 865 expression.this.upper() 866 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 867 else expression.this.lower() 868 ), 869 ) 870 871 return expression 872 873 def case_sensitive(self, text: str) -> bool: 874 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 875 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 876 return False 877 878 unsafe = ( 879 str.islower 880 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 881 else str.isupper 882 ) 883 return any(unsafe(char) for char in text) 884 885 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 886 """Checks if text can be identified given an identify option. 887 888 Args: 889 text: The text to check. 890 identify: 891 `"always"` or `True`: Always returns `True`. 892 `"safe"`: Only returns `True` if the identifier is case-insensitive. 893 894 Returns: 895 Whether the given text can be identified. 896 """ 897 if identify is True or identify == "always": 898 return True 899 900 if identify == "safe": 901 return not self.case_sensitive(text) 902 903 return False 904 905 def quote_identifier(self, expression: E, identify: bool = True) -> E: 906 """ 907 Adds quotes to a given identifier. 908 909 Args: 910 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 911 identify: If set to `False`, the quotes will only be added if the identifier is deemed 912 "unsafe", with respect to its characters and this dialect's normalization strategy. 913 """ 914 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 915 name = expression.this 916 expression.set( 917 "quoted", 918 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 919 ) 920 921 return expression 922 923 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 924 if isinstance(path, exp.Literal): 925 path_text = path.name 926 if path.is_number: 927 path_text = f"[{path_text}]" 928 try: 929 return parse_json_path(path_text, self) 930 except ParseError as e: 931 if self.STRICT_JSON_PATH_SYNTAX: 932 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 933 934 return path 935 936 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 937 return self.parser(**opts).parse(self.tokenize(sql), sql) 938 939 def parse_into( 940 self, expression_type: exp.IntoType, sql: str, **opts 941 ) -> t.List[t.Optional[exp.Expression]]: 942 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 943 944 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 945 return self.generator(**opts).generate(expression, copy=copy) 946 947 def transpile(self, sql: str, **opts) -> t.List[str]: 948 return [ 949 self.generate(expression, copy=False, **opts) if expression else "" 950 for expression in self.parse(sql) 951 ] 952 953 def tokenize(self, sql: str) -> t.List[Token]: 954 return self.tokenizer.tokenize(sql) 955 956 @property 957 def tokenizer(self) -> Tokenizer: 958 return self.tokenizer_class(dialect=self) 959 960 @property 961 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 962 return self.jsonpath_tokenizer_class(dialect=self) 963 964 def parser(self, **opts) -> Parser: 965 return self.parser_class(dialect=self, **opts) 966 967 def generator(self, **opts) -> Generator: 968 return self.generator_class(dialect=self, **opts) 969 970 971DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 972 973 974def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 975 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 976 977 978@unsupported_args("accuracy") 979def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 980 return self.func("APPROX_COUNT_DISTINCT", expression.this) 981 982 983def if_sql( 984 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 985) -> t.Callable[[Generator, exp.If], str]: 986 def _if_sql(self: Generator, expression: exp.If) -> str: 987 return self.func( 988 name, 989 expression.this, 990 expression.args.get("true"), 991 expression.args.get("false") or false_value, 992 ) 993 994 return _if_sql 995 996 997def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 998 this = expression.this 999 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 1000 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 1001 1002 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 1003 1004 1005def inline_array_sql(self: Generator, expression: exp.Array) -> str: 1006 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 1007 1008 1009def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 1010 elem = seq_get(expression.expressions, 0) 1011 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 1012 return self.func("ARRAY", elem) 1013 return inline_array_sql(self, expression) 1014 1015 1016def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 1017 return self.like_sql( 1018 exp.Like( 1019 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 1020 ) 1021 ) 1022 1023 1024def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 1025 zone = self.sql(expression, "this") 1026 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 1027 1028 1029def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 1030 if expression.args.get("recursive"): 1031 self.unsupported("Recursive CTEs are unsupported") 
1032 expression.args["recursive"] = False 1033 return self.with_sql(expression) 1034 1035 1036def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide, if_sql: str = "IF") -> str: 1037 n = self.sql(expression, "this") 1038 d = self.sql(expression, "expression") 1039 return f"{if_sql}(({d}) <> 0, ({n}) / ({d}), NULL)" 1040 1041 1042def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 1043 self.unsupported("TABLESAMPLE unsupported") 1044 return self.sql(expression.this) 1045 1046 1047def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 1048 self.unsupported("PIVOT unsupported") 1049 return "" 1050 1051 1052def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 1053 return self.cast_sql(expression) 1054 1055 1056def no_comment_column_constraint_sql( 1057 self: Generator, expression: exp.CommentColumnConstraint 1058) -> str: 1059 self.unsupported("CommentColumnConstraint unsupported") 1060 return "" 1061 1062 1063def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 1064 self.unsupported("MAP_FROM_ENTRIES unsupported") 1065 return "" 1066 1067 1068def property_sql(self: Generator, expression: exp.Property) -> str: 1069 return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" 1070 1071 1072def str_position_sql( 1073 self: Generator, 1074 expression: exp.StrPosition, 1075 generate_instance: bool = False, 1076 str_position_func_name: str = "STRPOS", 1077) -> str: 1078 this = self.sql(expression, "this") 1079 substr = self.sql(expression, "substr") 1080 position = self.sql(expression, "position") 1081 instance = expression.args.get("instance") if generate_instance else None 1082 position_offset = "" 1083 1084 if position: 1085 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1086 this = self.func("SUBSTR", this, position) 1087 position_offset = f" + {position} - 1" 1088 1089 return self.func(str_position_func_name, this, substr, instance) + position_offset 1090 1091 1092def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 1093 return ( 1094 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 1095 ) 1096 1097 1098def var_map_sql( 1099 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1100) -> str: 1101 keys = expression.args["keys"] 1102 values = expression.args["values"] 1103 1104 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1105 self.unsupported("Cannot convert array columns into map.") 1106 return self.func(map_func_name, keys, values) 1107 1108 args = [] 1109 for key, value in zip(keys.expressions, values.expressions): 1110 args.append(self.sql(key)) 1111 args.append(self.sql(value)) 1112 1113 return self.func(map_func_name, *args) 1114 1115 1116def build_formatted_time( 1117 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1118) -> t.Callable[[t.List], E]: 1119 """Helper used for time expressions. 1120 1121 Args: 1122 exp_class: the expression class to instantiate. 1123 dialect: target sql dialect. 1124 default: the default format, True being time. 1125 1126 Returns: 1127 A callable that can be used to return the appropriately formatted time expression. 
1128 """ 1129 1130 def _builder(args: t.List): 1131 return exp_class( 1132 this=seq_get(args, 0), 1133 format=Dialect[dialect].format_time( 1134 seq_get(args, 1) 1135 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1136 ), 1137 ) 1138 1139 return _builder 1140 1141 1142def time_format( 1143 dialect: DialectType = None, 1144) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1145 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1146 """ 1147 Returns the time format for a given expression, unless it's equivalent 1148 to the default time format of the dialect of interest. 1149 """ 1150 time_format = self.format_time(expression) 1151 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1152 1153 return _time_format 1154 1155 1156def build_date_delta( 1157 exp_class: t.Type[E], 1158 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1159 default_unit: t.Optional[str] = "DAY", 1160) -> t.Callable[[t.List], E]: 1161 def _builder(args: t.List) -> E: 1162 unit_based = len(args) == 3 1163 this = args[2] if unit_based else seq_get(args, 0) 1164 unit = None 1165 if unit_based or default_unit: 1166 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1167 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1168 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1169 1170 return _builder 1171 1172 1173def build_date_delta_with_interval( 1174 expression_class: t.Type[E], 1175) -> t.Callable[[t.List], t.Optional[E]]: 1176 def _builder(args: t.List) -> t.Optional[E]: 1177 if len(args) < 2: 1178 return None 1179 1180 interval = args[1] 1181 1182 if not isinstance(interval, exp.Interval): 1183 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1184 1185 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1186 1187 return _builder 1188 1189 1190def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1191 unit = seq_get(args, 0) 1192 this = seq_get(args, 1) 1193 1194 if isinstance(this, exp.Cast) and this.is_type("date"): 1195 return exp.DateTrunc(unit=unit, this=this) 1196 return exp.TimestampTrunc(this=this, unit=unit) 1197 1198 1199def date_add_interval_sql( 1200 data_type: str, kind: str 1201) -> t.Callable[[Generator, exp.Expression], str]: 1202 def func(self: Generator, expression: exp.Expression) -> str: 1203 this = self.sql(expression, "this") 1204 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1205 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1206 1207 return func 1208 1209 1210def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1211 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1212 args = [unit_to_str(expression), expression.this] 1213 if zone: 1214 args.append(expression.args.get("zone")) 1215 return self.func("DATE_TRUNC", *args) 1216 1217 return _timestamptrunc_sql 1218 1219 1220def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1221 zone = expression.args.get("zone") 1222 if not zone: 1223 from sqlglot.optimizer.annotate_types import annotate_types 1224 1225 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1226 return self.sql(exp.cast(expression.this, target_type)) 1227 if zone.name.lower() in TIMEZONES: 1228 return self.sql( 1229 
exp.AtTimeZone( 1230 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1231 zone=zone, 1232 ) 1233 ) 1234 return self.func("TIMESTAMP", expression.this, zone) 1235 1236 1237def no_time_sql(self: Generator, expression: exp.Time) -> str: 1238 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1239 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1240 expr = exp.cast( 1241 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1242 ) 1243 return self.sql(expr) 1244 1245 1246def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1247 this = expression.this 1248 expr = expression.expression 1249 1250 if expr.name.lower() in TIMEZONES: 1251 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1252 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1253 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1254 return self.sql(this) 1255 1256 this = exp.cast(this, exp.DataType.Type.DATE) 1257 expr = exp.cast(expr, exp.DataType.Type.TIME) 1258 1259 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP)) 1260 1261 1262def locate_to_strposition(args: t.List) -> exp.Expression: 1263 return exp.StrPosition( 1264 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 1265 ) 1266 1267 1268def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 1269 return self.func( 1270 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 1271 ) 1272 1273 1274def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1275 return self.sql( 1276 exp.Substring( 1277 this=expression.this, start=exp.Literal.number(1), length=expression.expression 1278 ) 1279 ) 1280 1281 1282def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1283 return self.sql( 1284 exp.Substring( 1285 this=expression.this, 1286 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 1287 ) 1288 ) 1289 1290 1291def timestrtotime_sql( 1292 self: Generator, 1293 expression: exp.TimeStrToTime, 1294 include_precision: bool = False, 1295) -> str: 1296 datatype = exp.DataType.build( 1297 exp.DataType.Type.TIMESTAMPTZ 1298 if expression.args.get("zone") 1299 else exp.DataType.Type.TIMESTAMP 1300 ) 1301 1302 if isinstance(expression.this, exp.Literal) and include_precision: 1303 precision = subsecond_precision(expression.this.name) 1304 if precision > 0: 1305 datatype = exp.DataType.build( 1306 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1307 ) 1308 1309 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect)) 1310 1311 1312def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 1313 return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) 1314 1315 1316# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 1317def encode_decode_sql( 1318 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1319) -> str: 1320 charset = expression.args.get("charset") 1321 if charset and charset.name.lower() != "utf-8": 1322 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1323 1324 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 1325 1326 1327def min_or_least(self: Generator, expression: 
exp.Min) -> str: 1328 name = "LEAST" if expression.expressions else "MIN" 1329 return rename_func(name)(self, expression) 1330 1331 1332def max_or_greatest(self: Generator, expression: exp.Max) -> str: 1333 name = "GREATEST" if expression.expressions else "MAX" 1334 return rename_func(name)(self, expression) 1335 1336 1337def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1338 cond = expression.this 1339 1340 if isinstance(expression.this, exp.Distinct): 1341 cond = expression.this.expressions[0] 1342 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1343 1344 return self.func("sum", exp.func("if", cond, 1, 0)) 1345 1346 1347def trim_sql(self: Generator, expression: exp.Trim) -> str: 1348 target = self.sql(expression, "this") 1349 trim_type = self.sql(expression, "position") 1350 remove_chars = self.sql(expression, "expression") 1351 collation = self.sql(expression, "collation") 1352 1353 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1354 if not remove_chars: 1355 return self.trim_sql(expression) 1356 1357 trim_type = f"{trim_type} " if trim_type else "" 1358 remove_chars = f"{remove_chars} " if remove_chars else "" 1359 from_part = "FROM " if trim_type or remove_chars else "" 1360 collation = f" COLLATE {collation}" if collation else "" 1361 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 1362 1363 1364def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 1365 return self.func("STRPTIME", expression.this, self.format_time(expression)) 1366 1367 1368def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 1369 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 1370 1371 1372def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 1373 delim, *rest_args = expression.expressions 1374 return self.sql( 1375 reduce( 1376 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 1377 rest_args, 1378 ) 1379 ) 1380 1381 1382@unsupported_args("position", "occurrence", "parameters") 1383def regexp_extract_sql( 1384 self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll 1385) -> str: 1386 group = expression.args.get("group") 1387 1388 # Do not render group if it's the default value for this dialect 1389 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1390 group = None 1391 1392 return self.func(expression.sql_name(), expression.this, expression.expression, group) 1393 1394 1395@unsupported_args("position", "occurrence", "modifiers") 1396def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1397 return self.func( 1398 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1399 ) 1400 1401 1402def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1403 names = [] 1404 for agg in aggregations: 1405 if isinstance(agg, exp.Alias): 1406 names.append(agg.alias) 1407 else: 1408 """ 1409 This case corresponds to aggregations without aliases being used as suffixes 1410 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1411 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1412 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 
1413 """ 1414 agg_all_unquoted = agg.transform( 1415 lambda node: ( 1416 exp.Identifier(this=node.name, quoted=False) 1417 if isinstance(node, exp.Identifier) 1418 else node 1419 ) 1420 ) 1421 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1422 1423 return names 1424 1425 1426def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1427 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1428 1429 1430# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1431def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1432 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1433 1434 1435def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1436 return self.func("MAX", expression.this) 1437 1438 1439def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1440 a = self.sql(expression.left) 1441 b = self.sql(expression.right) 1442 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1443 1444 1445def is_parse_json(expression: exp.Expression) -> bool: 1446 return isinstance(expression, exp.ParseJSON) or ( 1447 isinstance(expression, exp.Cast) and expression.is_type("json") 1448 ) 1449 1450 1451def isnull_to_is_null(args: t.List) -> exp.Expression: 1452 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1453 1454 1455def generatedasidentitycolumnconstraint_sql( 1456 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1457) -> str: 1458 start = self.sql(expression, "start") or "1" 1459 increment = self.sql(expression, "increment") or "1" 1460 return f"IDENTITY({start}, {increment})" 1461 1462 1463def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1464 @unsupported_args("count") 1465 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1466 return self.func(name, expression.this, expression.expression) 1467 1468 return _arg_max_or_min_sql 1469 1470 1471def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1472 this = expression.this.copy() 1473 1474 return_type = expression.return_type 1475 if return_type.is_type(exp.DataType.Type.DATE): 1476 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1477 # can truncate timestamp strings, because some dialects can't cast them to DATE 1478 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1479 1480 expression.this.replace(exp.cast(this, return_type)) 1481 return expression 1482 1483 1484def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1485 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1486 if cast and isinstance(expression, exp.TsOrDsAdd): 1487 expression = ts_or_ds_add_cast(expression) 1488 1489 return self.func( 1490 name, 1491 unit_to_var(expression), 1492 expression.expression, 1493 expression.this, 1494 ) 1495 1496 return _delta_sql 1497 1498 1499def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1500 unit = expression.args.get("unit") 1501 1502 if isinstance(unit, exp.Placeholder): 1503 return unit 1504 if unit: 1505 return exp.Literal.string(unit.name) 1506 return exp.Literal.string(default) if default else None 1507 1508 1509def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1510 unit = expression.args.get("unit") 1511 1512 if isinstance(unit, (exp.Var, exp.Placeholder)): 1513 
return unit 1514 return exp.Var(this=default) if default else None 1515 1516 1517@t.overload 1518def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: 1519 pass 1520 1521 1522@t.overload 1523def map_date_part( 1524 part: t.Optional[exp.Expression], dialect: DialectType = Dialect 1525) -> t.Optional[exp.Expression]: 1526 pass 1527 1528 1529def map_date_part(part, dialect: DialectType = Dialect): 1530 mapped = ( 1531 Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None 1532 ) 1533 return exp.var(mapped) if mapped else part 1534 1535 1536def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1537 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1538 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1539 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1540 1541 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1542 1543 1544def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1545 """Remove table refs from columns in when statements.""" 1546 alias = expression.this.args.get("alias") 1547 1548 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1549 return self.dialect.normalize_identifier(identifier).name if identifier else None 1550 1551 targets = {normalize(expression.this.this)} 1552 1553 if alias: 1554 targets.add(normalize(alias.this)) 1555 1556 for when in expression.args["whens"].expressions: 1557 # only remove the target names from the THEN clause 1558 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1559 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1560 then = when.args.get("then") 1561 if then: 1562 then.transform( 1563 lambda node: ( 1564 exp.column(node.this) 1565 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1566 else node 1567 ), 1568 copy=False, 1569 ) 1570 1571 return self.merge_sql(expression) 1572 1573 1574def build_json_extract_path( 1575 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1576) -> t.Callable[[t.List], F]: 1577 def _builder(args: t.List) -> F: 1578 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1579 for arg in args[1:]: 1580 if not isinstance(arg, exp.Literal): 1581 # We use the fallback parser because we can't really transpile non-literals safely 1582 return expr_type.from_arg_list(args) 1583 1584 text = arg.name 1585 if is_int(text): 1586 index = int(text) 1587 segments.append( 1588 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1589 ) 1590 else: 1591 segments.append(exp.JSONPathKey(this=text)) 1592 1593 # This is done to avoid failing in the expression validator due to the arg count 1594 del args[2:] 1595 return expr_type( 1596 this=seq_get(args, 0), 1597 expression=exp.JSONPath(expressions=segments), 1598 only_json_types=arrow_req_json_type, 1599 ) 1600 1601 return _builder 1602 1603 1604def json_extract_segments( 1605 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1606) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1607 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1608 path = expression.expression 1609 if not isinstance(path, exp.JSONPath): 1610 return rename_func(name)(self, expression) 1611 1612 escape = path.args.get("escape") 1613 1614 segments = [] 1615 for segment in path.expressions: 1616 path = self.sql(segment) 1617 if path: 1618 if 
isinstance(segment, exp.JSONPathPart) and ( 1619 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1620 ): 1621 if escape: 1622 path = self.escape_str(path) 1623 1624 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1625 1626 segments.append(path) 1627 1628 if op: 1629 return f" {op} ".join([self.sql(expression.this), *segments]) 1630 return self.func(name, expression.this, *segments) 1631 1632 return _json_extract_segments 1633 1634 1635def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1636 if isinstance(expression.this, exp.JSONPathWildcard): 1637 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1638 1639 return expression.name 1640 1641 1642def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1643 cond = expression.expression 1644 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1645 alias = cond.expressions[0] 1646 cond = cond.this 1647 elif isinstance(cond, exp.Predicate): 1648 alias = "_u" 1649 else: 1650 self.unsupported("Unsupported filter condition") 1651 return "" 1652 1653 unnest = exp.Unnest(expressions=[expression.this]) 1654 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1655 return self.sql(exp.Array(expressions=[filtered])) 1656 1657 1658def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1659 return self.func( 1660 "TO_NUMBER", 1661 expression.this, 1662 expression.args.get("format"), 1663 expression.args.get("nlsparam"), 1664 ) 1665 1666 1667def build_default_decimal_type( 1668 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1669) -> t.Callable[[exp.DataType], exp.DataType]: 1670 def _builder(dtype: exp.DataType) -> exp.DataType: 1671 if dtype.expressions or precision is None: 1672 return dtype 1673 1674 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1675 return exp.DataType.build(f"DECIMAL({params})") 1676 1677 return _builder 1678 1679 1680def build_timestamp_from_parts(args: t.List) -> exp.Func: 1681 if len(args) == 2: 1682 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1683 # so we parse this into Anonymous for now instead of introducing complexity 1684 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1685 1686 return exp.TimestampFromParts.from_arg_list(args) 1687 1688 1689def sha256_sql(self: Generator, expression: exp.SHA2) -> str: 1690 return self.func(f"SHA{expression.text('length') or '256'}", expression.this) 1691 1692 1693def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1694 start = expression.args.get("start") 1695 end = expression.args.get("end") 1696 step = expression.args.get("step") 1697 1698 if isinstance(start, exp.Cast): 1699 target_type = start.to 1700 elif isinstance(end, exp.Cast): 1701 target_type = end.to 1702 else: 1703 target_type = None 1704 1705 if start and end and target_type and target_type.is_type("date", "timestamp"): 1706 if isinstance(start, exp.Cast) and target_type is start.to: 1707 end = exp.cast(end, target_type) 1708 else: 1709 start = exp.cast(start, target_type) 1710 1711 return self.func("SEQUENCE", start, end, step) 1712 1713 1714def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: 1715 def _builder(args: t.List, dialect: Dialect) -> E: 1716 return expr_type( 1717 this=seq_get(args, 0), 1718 expression=seq_get(args, 1), 1719 group=seq_get(args, 2) or 
exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1720 parameters=seq_get(args, 3), 1721 ) 1722 1723 return _builder 1724 1725 1726def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1727 if isinstance(expression.this, exp.Explode): 1728 return self.sql( 1729 exp.Join( 1730 this=exp.Unnest( 1731 expressions=[expression.this.this], 1732 alias=expression.args.get("alias"), 1733 offset=isinstance(expression.this, exp.Posexplode), 1734 ), 1735 kind="cross", 1736 ) 1737 ) 1738 return self.lateral_sql(expression) 1739 1740 1741def timestampdiff_sql(self: Generator, expression: exp.DatetimeDiff | exp.TimestampDiff) -> str: 1742 return self.func("TIMESTAMPDIFF", expression.unit, expression.expression, expression.this) 1743 1744 1745def no_make_interval_sql(self: Generator, expression: exp.MakeInterval, sep: str = ", ") -> str: 1746 args = [] 1747 for unit, value in expression.args.items(): 1748 if isinstance(value, exp.Kwarg): 1749 value = value.expression 1750 1751 args.append(f"{value} {unit}") 1752 1753 return f"INTERVAL '{self.format_args(*args, sep=sep)}'"
55class Dialects(str, Enum): 56 """Dialects supported by SQLGLot.""" 57 58 DIALECT = "" 59 60 ATHENA = "athena" 61 BIGQUERY = "bigquery" 62 CLICKHOUSE = "clickhouse" 63 DATABRICKS = "databricks" 64 DORIS = "doris" 65 DRILL = "drill" 66 DUCKDB = "duckdb" 67 HIVE = "hive" 68 MATERIALIZE = "materialize" 69 MYSQL = "mysql" 70 ORACLE = "oracle" 71 POSTGRES = "postgres" 72 PRESTO = "presto" 73 PRQL = "prql" 74 REDSHIFT = "redshift" 75 RISINGWAVE = "risingwave" 76 SNOWFLAKE = "snowflake" 77 SPARK = "spark" 78 SPARK2 = "spark2" 79 SQLITE = "sqlite" 80 STARROCKS = "starrocks" 81 TABLEAU = "tableau" 82 TERADATA = "teradata" 83 TRINO = "trino" 84 TSQL = "tsql"
Dialects supported by SQLGlot.
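Since Dialects is a str-valued enum, its members can usually be passed anywhere a plain dialect name is accepted. A minimal sketch (assuming the top-level sqlglot.transpile API, which takes read/write dialect names):

    import sqlglot
    from sqlglot.dialects.dialect import Dialects

    # Dialects.MYSQL / Dialects.DUCKDB behave like the strings "mysql" / "duckdb"
    sqlglot.transpile("SELECT 1", read=Dialects.MYSQL, write=Dialects.DUCKDB)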
87class NormalizationStrategy(str, AutoName): 88 """Specifies the strategy according to which identifiers should be normalized.""" 89 90 LOWERCASE = auto() 91 """Unquoted identifiers are lowercased.""" 92 93 UPPERCASE = auto() 94 """Unquoted identifiers are uppercased.""" 95 96 CASE_SENSITIVE = auto() 97 """Always case-sensitive, regardless of quotes.""" 98 99 CASE_INSENSITIVE = auto() 100 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
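The strategy is consumed by Dialect.__init__ (shown further below), so it can also be overridden per dialect instance by name. A small sketch:

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    # Strings are upper-cased and looked up against the enum, so "case_sensitive" works too.
    d = Dialect(normalization_strategy="case_sensitive")
    assert d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE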
225class Dialect(metaclass=_Dialect): 226 INDEX_OFFSET = 0 227 """The base index offset for arrays.""" 228 229 WEEK_OFFSET = 0 230 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 231 232 UNNEST_COLUMN_ONLY = False 233 """Whether `UNNEST` table aliases are treated as column aliases.""" 234 235 ALIAS_POST_TABLESAMPLE = False 236 """Whether the table alias comes after tablesample.""" 237 238 TABLESAMPLE_SIZE_IS_PERCENT = False 239 """Whether a size in the table sample clause represents percentage.""" 240 241 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 242 """Specifies the strategy according to which identifiers should be normalized.""" 243 244 IDENTIFIERS_CAN_START_WITH_DIGIT = False 245 """Whether an unquoted identifier can start with a digit.""" 246 247 DPIPE_IS_STRING_CONCAT = True 248 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 249 250 STRICT_STRING_CONCAT = False 251 """Whether `CONCAT`'s arguments must be strings.""" 252 253 SUPPORTS_USER_DEFINED_TYPES = True 254 """Whether user-defined data types are supported.""" 255 256 SUPPORTS_SEMI_ANTI_JOIN = True 257 """Whether `SEMI` or `ANTI` joins are supported.""" 258 259 SUPPORTS_COLUMN_JOIN_MARKS = False 260 """Whether the old-style outer join (+) syntax is supported.""" 261 262 COPY_PARAMS_ARE_CSV = True 263 """Separator of COPY statement parameters.""" 264 265 NORMALIZE_FUNCTIONS: bool | str = "upper" 266 """ 267 Determines how function names are going to be normalized. 268 Possible values: 269 "upper" or True: Convert names to uppercase. 270 "lower": Convert names to lowercase. 271 False: Disables function name normalization. 272 """ 273 274 PRESERVE_ORIGINAL_NAMES: bool = False 275 """ 276 Whether the name of the function should be preserved inside the node's metadata, 277 can be useful for roundtripping deprecated vs new functions that share an AST node 278 e.g JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery 279 """ 280 281 LOG_BASE_FIRST: t.Optional[bool] = True 282 """ 283 Whether the base comes first in the `LOG` function. 284 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 285 """ 286 287 NULL_ORDERING = "nulls_are_small" 288 """ 289 Default `NULL` ordering method to use if not explicitly set. 290 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 291 """ 292 293 TYPED_DIVISION = False 294 """ 295 Whether the behavior of `a / b` depends on the types of `a` and `b`. 296 False means `a / b` is always float division. 297 True means `a / b` is integer division if both `a` and `b` are integers. 
298 """ 299 300 SAFE_DIVISION = False 301 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 302 303 CONCAT_COALESCE = False 304 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 305 306 HEX_LOWERCASE = False 307 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 308 309 DATE_FORMAT = "'%Y-%m-%d'" 310 DATEINT_FORMAT = "'%Y%m%d'" 311 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 312 313 TIME_MAPPING: t.Dict[str, str] = {} 314 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 315 316 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 317 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 318 FORMAT_MAPPING: t.Dict[str, str] = {} 319 """ 320 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 321 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 322 """ 323 324 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 325 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 326 327 PSEUDOCOLUMNS: t.Set[str] = set() 328 """ 329 Columns that are auto-generated by the engine corresponding to this dialect. 330 For example, such columns may be excluded from `SELECT *` queries. 331 """ 332 333 PREFER_CTE_ALIAS_COLUMN = False 334 """ 335 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 336 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 337 any projection aliases in the subquery. 338 339 For example, 340 WITH y(c) AS ( 341 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 342 ) SELECT c FROM y; 343 344 will be rewritten as 345 346 WITH y(c) AS ( 347 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 348 ) SELECT c FROM y; 349 """ 350 351 COPY_PARAMS_ARE_CSV = True 352 """ 353 Whether COPY statement parameters are separated by comma or whitespace 354 """ 355 356 FORCE_EARLY_ALIAS_REF_EXPANSION = False 357 """ 358 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 
359 360 For example: 361 WITH data AS ( 362 SELECT 363 1 AS id, 364 2 AS my_id 365 ) 366 SELECT 367 id AS my_id 368 FROM 369 data 370 WHERE 371 my_id = 1 372 GROUP BY 373 my_id, 374 HAVING 375 my_id = 1 376 377 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 378 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 379 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 380 - Clickhouse, which will forward the alias across the query i.e it resolves 381 to "WHERE id = 1 GROUP BY id HAVING id = 1" 382 """ 383 384 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 385 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 386 387 SUPPORTS_ORDER_BY_ALL = False 388 """ 389 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 390 """ 391 392 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 393 """ 394 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 395 as the former is of type INT[] vs the latter which is SUPER 396 """ 397 398 SUPPORTS_FIXED_SIZE_ARRAYS = False 399 """ 400 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 401 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 402 be interpreted as a subscript/index operator. 403 """ 404 405 STRICT_JSON_PATH_SYNTAX = True 406 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 407 408 ON_CONDITION_EMPTY_BEFORE_ERROR = True 409 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 410 411 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 412 """Whether ArrayAgg needs to filter NULL values.""" 413 414 PROMOTE_TO_INFERRED_DATETIME_TYPE = False 415 """ 416 This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted 417 to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal 418 is cast to x's type to match it instead. 419 """ 420 421 SUPPORTS_VALUES_DEFAULT = True 422 """Whether the DEFAULT keyword is supported in the VALUES clause.""" 423 424 REGEXP_EXTRACT_DEFAULT_GROUP = 0 425 """The default value for the capturing group.""" 426 427 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 428 exp.Except: True, 429 exp.Intersect: True, 430 exp.Union: True, 431 } 432 """ 433 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 434 must be explicitly specified. 435 """ 436 437 CREATABLE_KIND_MAPPING: dict[str, str] = {} 438 """ 439 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 440 equivalent of CREATE SCHEMA is CREATE DATABASE. 
441 """ 442 443 # --- Autofilled --- 444 445 tokenizer_class = Tokenizer 446 jsonpath_tokenizer_class = JSONPathTokenizer 447 parser_class = Parser 448 generator_class = Generator 449 450 # A trie of the time_mapping keys 451 TIME_TRIE: t.Dict = {} 452 FORMAT_TRIE: t.Dict = {} 453 454 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 455 INVERSE_TIME_TRIE: t.Dict = {} 456 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 457 INVERSE_FORMAT_TRIE: t.Dict = {} 458 459 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 460 461 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 462 463 # Delimiters for string literals and identifiers 464 QUOTE_START = "'" 465 QUOTE_END = "'" 466 IDENTIFIER_START = '"' 467 IDENTIFIER_END = '"' 468 469 # Delimiters for bit, hex, byte and unicode literals 470 BIT_START: t.Optional[str] = None 471 BIT_END: t.Optional[str] = None 472 HEX_START: t.Optional[str] = None 473 HEX_END: t.Optional[str] = None 474 BYTE_START: t.Optional[str] = None 475 BYTE_END: t.Optional[str] = None 476 UNICODE_START: t.Optional[str] = None 477 UNICODE_END: t.Optional[str] = None 478 479 DATE_PART_MAPPING = { 480 "Y": "YEAR", 481 "YY": "YEAR", 482 "YYY": "YEAR", 483 "YYYY": "YEAR", 484 "YR": "YEAR", 485 "YEARS": "YEAR", 486 "YRS": "YEAR", 487 "MM": "MONTH", 488 "MON": "MONTH", 489 "MONS": "MONTH", 490 "MONTHS": "MONTH", 491 "D": "DAY", 492 "DD": "DAY", 493 "DAYS": "DAY", 494 "DAYOFMONTH": "DAY", 495 "DAY OF WEEK": "DAYOFWEEK", 496 "WEEKDAY": "DAYOFWEEK", 497 "DOW": "DAYOFWEEK", 498 "DW": "DAYOFWEEK", 499 "WEEKDAY_ISO": "DAYOFWEEKISO", 500 "DOW_ISO": "DAYOFWEEKISO", 501 "DW_ISO": "DAYOFWEEKISO", 502 "DAY OF YEAR": "DAYOFYEAR", 503 "DOY": "DAYOFYEAR", 504 "DY": "DAYOFYEAR", 505 "W": "WEEK", 506 "WK": "WEEK", 507 "WEEKOFYEAR": "WEEK", 508 "WOY": "WEEK", 509 "WY": "WEEK", 510 "WEEK_ISO": "WEEKISO", 511 "WEEKOFYEARISO": "WEEKISO", 512 "WEEKOFYEAR_ISO": "WEEKISO", 513 "Q": "QUARTER", 514 "QTR": "QUARTER", 515 "QTRS": "QUARTER", 516 "QUARTERS": "QUARTER", 517 "H": "HOUR", 518 "HH": "HOUR", 519 "HR": "HOUR", 520 "HOURS": "HOUR", 521 "HRS": "HOUR", 522 "M": "MINUTE", 523 "MI": "MINUTE", 524 "MIN": "MINUTE", 525 "MINUTES": "MINUTE", 526 "MINS": "MINUTE", 527 "S": "SECOND", 528 "SEC": "SECOND", 529 "SECONDS": "SECOND", 530 "SECS": "SECOND", 531 "MS": "MILLISECOND", 532 "MSEC": "MILLISECOND", 533 "MSECS": "MILLISECOND", 534 "MSECOND": "MILLISECOND", 535 "MSECONDS": "MILLISECOND", 536 "MILLISEC": "MILLISECOND", 537 "MILLISECS": "MILLISECOND", 538 "MILLISECON": "MILLISECOND", 539 "MILLISECONDS": "MILLISECOND", 540 "US": "MICROSECOND", 541 "USEC": "MICROSECOND", 542 "USECS": "MICROSECOND", 543 "MICROSEC": "MICROSECOND", 544 "MICROSECS": "MICROSECOND", 545 "USECOND": "MICROSECOND", 546 "USECONDS": "MICROSECOND", 547 "MICROSECONDS": "MICROSECOND", 548 "NS": "NANOSECOND", 549 "NSEC": "NANOSECOND", 550 "NANOSEC": "NANOSECOND", 551 "NSECOND": "NANOSECOND", 552 "NSECONDS": "NANOSECOND", 553 "NANOSECS": "NANOSECOND", 554 "EPOCH_SECOND": "EPOCH", 555 "EPOCH_SECONDS": "EPOCH", 556 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 557 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 558 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 559 "TZH": "TIMEZONE_HOUR", 560 "TZM": "TIMEZONE_MINUTE", 561 "DEC": "DECADE", 562 "DECS": "DECADE", 563 "DECADES": "DECADE", 564 "MIL": "MILLENIUM", 565 "MILS": "MILLENIUM", 566 "MILLENIA": "MILLENIUM", 567 "C": "CENTURY", 568 "CENT": "CENTURY", 569 "CENTS": "CENTURY", 570 "CENTURIES": "CENTURY", 571 } 572 573 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 574 exp.DataType.Type.BIGINT: 
{ 575 exp.ApproxDistinct, 576 exp.ArraySize, 577 exp.Length, 578 }, 579 exp.DataType.Type.BOOLEAN: { 580 exp.Between, 581 exp.Boolean, 582 exp.In, 583 exp.RegexpLike, 584 }, 585 exp.DataType.Type.DATE: { 586 exp.CurrentDate, 587 exp.Date, 588 exp.DateFromParts, 589 exp.DateStrToDate, 590 exp.DiToDate, 591 exp.StrToDate, 592 exp.TimeStrToDate, 593 exp.TsOrDsToDate, 594 }, 595 exp.DataType.Type.DATETIME: { 596 exp.CurrentDatetime, 597 exp.Datetime, 598 exp.DatetimeAdd, 599 exp.DatetimeSub, 600 }, 601 exp.DataType.Type.DOUBLE: { 602 exp.ApproxQuantile, 603 exp.Avg, 604 exp.Exp, 605 exp.Ln, 606 exp.Log, 607 exp.Pow, 608 exp.Quantile, 609 exp.Round, 610 exp.SafeDivide, 611 exp.Sqrt, 612 exp.Stddev, 613 exp.StddevPop, 614 exp.StddevSamp, 615 exp.ToDouble, 616 exp.Variance, 617 exp.VariancePop, 618 }, 619 exp.DataType.Type.INT: { 620 exp.Ceil, 621 exp.DatetimeDiff, 622 exp.DateDiff, 623 exp.TimestampDiff, 624 exp.TimeDiff, 625 exp.DateToDi, 626 exp.Levenshtein, 627 exp.Sign, 628 exp.StrPosition, 629 exp.TsOrDiToDi, 630 }, 631 exp.DataType.Type.JSON: { 632 exp.ParseJSON, 633 }, 634 exp.DataType.Type.TIME: { 635 exp.Time, 636 }, 637 exp.DataType.Type.TIMESTAMP: { 638 exp.CurrentTime, 639 exp.CurrentTimestamp, 640 exp.StrToTime, 641 exp.TimeAdd, 642 exp.TimeStrToTime, 643 exp.TimeSub, 644 exp.TimestampAdd, 645 exp.TimestampSub, 646 exp.UnixToTime, 647 }, 648 exp.DataType.Type.TINYINT: { 649 exp.Day, 650 exp.Month, 651 exp.Week, 652 exp.Year, 653 exp.Quarter, 654 }, 655 exp.DataType.Type.VARCHAR: { 656 exp.ArrayConcat, 657 exp.Concat, 658 exp.ConcatWs, 659 exp.DateToDateStr, 660 exp.GroupConcat, 661 exp.Initcap, 662 exp.Lower, 663 exp.Substring, 664 exp.String, 665 exp.TimeToStr, 666 exp.TimeToTimeStr, 667 exp.Trim, 668 exp.TsOrDsToDateStr, 669 exp.UnixToStr, 670 exp.UnixToTimeStr, 671 exp.Upper, 672 }, 673 } 674 675 ANNOTATORS: AnnotatorsType = { 676 **{ 677 expr_type: lambda self, e: self._annotate_unary(e) 678 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 679 }, 680 **{ 681 expr_type: lambda self, e: self._annotate_binary(e) 682 for expr_type in subclasses(exp.__name__, exp.Binary) 683 }, 684 **{ 685 expr_type: _annotate_with_type_lambda(data_type) 686 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 687 for expr_type in expressions 688 }, 689 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 690 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 691 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 692 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 693 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 694 exp.Bracket: lambda self, e: self._annotate_bracket(e), 695 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 696 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 697 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 698 exp.Count: lambda self, e: self._annotate_with_type( 699 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 700 ), 701 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 702 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 703 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 704 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 705 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 706 exp.Div: lambda self, e: self._annotate_div(e), 707 exp.Dot: lambda 
self, e: self._annotate_dot(e), 708 exp.Explode: lambda self, e: self._annotate_explode(e), 709 exp.Extract: lambda self, e: self._annotate_extract(e), 710 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 711 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 712 e, exp.DataType.build("ARRAY<DATE>") 713 ), 714 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 715 e, exp.DataType.build("ARRAY<TIMESTAMP>") 716 ), 717 exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 718 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 719 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 720 exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 721 exp.Literal: lambda self, e: self._annotate_literal(e), 722 exp.Map: lambda self, e: self._annotate_map(e), 723 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 724 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 725 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 726 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 727 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 728 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 729 exp.Struct: lambda self, e: self._annotate_struct(e), 730 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 731 exp.Timestamp: lambda self, e: self._annotate_with_type( 732 e, 733 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 734 ), 735 exp.ToMap: lambda self, e: self._annotate_to_map(e), 736 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 737 exp.Unnest: lambda self, e: self._annotate_unnest(e), 738 exp.VarMap: lambda self, e: self._annotate_map(e), 739 } 740 741 @classmethod 742 def get_or_raise(cls, dialect: DialectType) -> Dialect: 743 """ 744 Look up a dialect in the global dialect registry and return it if it exists. 745 746 Args: 747 dialect: The target dialect. If this is a string, it can be optionally followed by 748 additional key-value pairs that are separated by commas and are used to specify 749 dialect settings, such as whether the dialect's identifiers are case-sensitive. 750 751 Example: 752 >>> dialect = dialect_class = get_or_raise("duckdb") 753 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 754 755 Returns: 756 The corresponding Dialect instance. 757 """ 758 759 if not dialect: 760 return cls() 761 if isinstance(dialect, _Dialect): 762 return dialect() 763 if isinstance(dialect, Dialect): 764 return dialect 765 if isinstance(dialect, str): 766 try: 767 dialect_name, *kv_strings = dialect.split(",") 768 kv_pairs = (kv.split("=") for kv in kv_strings) 769 kwargs = {} 770 for pair in kv_pairs: 771 key = pair[0].strip() 772 value: t.Union[bool | str | None] = None 773 774 if len(pair) == 1: 775 # Default initialize standalone settings to True 776 value = True 777 elif len(pair) == 2: 778 value = pair[1].strip() 779 780 kwargs[key] = to_bool(value) 781 782 except ValueError: 783 raise ValueError( 784 f"Invalid dialect format: '{dialect}'. " 785 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 
786 ) 787 788 result = cls.get(dialect_name.strip()) 789 if not result: 790 from difflib import get_close_matches 791 792 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 793 if similar: 794 similar = f" Did you mean {similar}?" 795 796 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 797 798 return result(**kwargs) 799 800 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 801 802 @classmethod 803 def format_time( 804 cls, expression: t.Optional[str | exp.Expression] 805 ) -> t.Optional[exp.Expression]: 806 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 807 if isinstance(expression, str): 808 return exp.Literal.string( 809 # the time formats are quoted 810 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 811 ) 812 813 if expression and expression.is_string: 814 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 815 816 return expression 817 818 def __init__(self, **kwargs) -> None: 819 normalization_strategy = kwargs.pop("normalization_strategy", None) 820 821 if normalization_strategy is None: 822 self.normalization_strategy = self.NORMALIZATION_STRATEGY 823 else: 824 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 825 826 self.settings = kwargs 827 828 def __eq__(self, other: t.Any) -> bool: 829 # Does not currently take dialect state into account 830 return type(self) == other 831 832 def __hash__(self) -> int: 833 # Does not currently take dialect state into account 834 return hash(type(self)) 835 836 def normalize_identifier(self, expression: E) -> E: 837 """ 838 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 839 840 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 841 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 842 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 843 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 844 845 There are also dialects like Spark, which are case-insensitive even when quotes are 846 present, and dialects like MySQL, whose resolution rules match those employed by the 847 underlying operating system, for example they may always be case-sensitive in Linux. 848 849 Finally, the normalization behavior of some engines can even be controlled through flags, 850 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 851 852 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 853 that it can analyze queries in the optimizer and successfully capture their semantics. 
854 """ 855 if ( 856 isinstance(expression, exp.Identifier) 857 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 858 and ( 859 not expression.quoted 860 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 861 ) 862 ): 863 expression.set( 864 "this", 865 ( 866 expression.this.upper() 867 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 868 else expression.this.lower() 869 ), 870 ) 871 872 return expression 873 874 def case_sensitive(self, text: str) -> bool: 875 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 876 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 877 return False 878 879 unsafe = ( 880 str.islower 881 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 882 else str.isupper 883 ) 884 return any(unsafe(char) for char in text) 885 886 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 887 """Checks if text can be identified given an identify option. 888 889 Args: 890 text: The text to check. 891 identify: 892 `"always"` or `True`: Always returns `True`. 893 `"safe"`: Only returns `True` if the identifier is case-insensitive. 894 895 Returns: 896 Whether the given text can be identified. 897 """ 898 if identify is True or identify == "always": 899 return True 900 901 if identify == "safe": 902 return not self.case_sensitive(text) 903 904 return False 905 906 def quote_identifier(self, expression: E, identify: bool = True) -> E: 907 """ 908 Adds quotes to a given identifier. 909 910 Args: 911 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 912 identify: If set to `False`, the quotes will only be added if the identifier is deemed 913 "unsafe", with respect to its characters and this dialect's normalization strategy. 914 """ 915 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 916 name = expression.this 917 expression.set( 918 "quoted", 919 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 920 ) 921 922 return expression 923 924 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 925 if isinstance(path, exp.Literal): 926 path_text = path.name 927 if path.is_number: 928 path_text = f"[{path_text}]" 929 try: 930 return parse_json_path(path_text, self) 931 except ParseError as e: 932 if self.STRICT_JSON_PATH_SYNTAX: 933 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 934 935 return path 936 937 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 938 return self.parser(**opts).parse(self.tokenize(sql), sql) 939 940 def parse_into( 941 self, expression_type: exp.IntoType, sql: str, **opts 942 ) -> t.List[t.Optional[exp.Expression]]: 943 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 944 945 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 946 return self.generator(**opts).generate(expression, copy=copy) 947 948 def transpile(self, sql: str, **opts) -> t.List[str]: 949 return [ 950 self.generate(expression, copy=False, **opts) if expression else "" 951 for expression in self.parse(sql) 952 ] 953 954 def tokenize(self, sql: str) -> t.List[Token]: 955 return self.tokenizer.tokenize(sql) 956 957 @property 958 def tokenizer(self) -> Tokenizer: 959 return self.tokenizer_class(dialect=self) 960 961 @property 962 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 963 return self.jsonpath_tokenizer_class(dialect=self) 964 965 def parser(self, **opts) -> Parser: 966 return self.parser_class(dialect=self, **opts) 967 968 def generator(self, **opts) -> Generator: 969 return self.generator_class(dialect=self, **opts)
818 def __init__(self, **kwargs) -> None: 819 normalization_strategy = kwargs.pop("normalization_strategy", None) 820 821 if normalization_strategy is None: 822 self.normalization_strategy = self.NORMALIZATION_STRATEGY 823 else: 824 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 825 826 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
- "upper" or True: Convert names to uppercase.
- "lower": Convert names to lowercase.
- False: Disables function name normalization.
Whether the name of the function should be preserved inside the node's metadata. This can be useful for roundtripping deprecated vs. new functions that share an AST node, e.g. JSON_VALUE vs. JSON_EXTRACT_SCALAR in BigQuery.
Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG).
Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".
Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.
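These division flags are plain class attributes, so a quick way to see how a given engine is modeled is to inspect them on its registered dialect. A sketch (the printed values depend on the dialect you pick):

    from sqlglot.dialects.dialect import Dialect

    d = Dialect.get_or_raise("duckdb")
    # TYPED_DIVISION: True means a / b is integer division when both operands are integers.
    # SAFE_DIVISION: True means division by zero yields NULL instead of raising an error.
    print(d.TYPED_DIVISION, d.SAFE_DIVISION)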
A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.
Mapping of an escaped sequence (the two-character sequence \n) to its unescaped version (an actual newline character).
Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
    WITH data AS (
        SELECT
            1 AS id,
            2 AS my_id
    )
    SELECT
        id AS my_id
    FROM
        data
    WHERE
        my_id = 1
    GROUP BY
        my_id,
    HAVING
        my_id = 1

In most dialects, "my_id" would refer to "data.my_id" across the query, except:
- BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- Clickhouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns), as in DuckDB and Spark3/Databricks.
Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3), as the former is of type INT[] while the latter is SUPER.
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator.
Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.
Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).
This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal is cast to x's type to match it instead.
Whether a set operation uses DISTINCT by default. This is None when either DISTINCT or ALL must be explicitly specified.
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
741 @classmethod 742 def get_or_raise(cls, dialect: DialectType) -> Dialect: 743 """ 744 Look up a dialect in the global dialect registry and return it if it exists. 745 746 Args: 747 dialect: The target dialect. If this is a string, it can be optionally followed by 748 additional key-value pairs that are separated by commas and are used to specify 749 dialect settings, such as whether the dialect's identifiers are case-sensitive. 750 751 Example: 752 >>> dialect = dialect_class = get_or_raise("duckdb") 753 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 754 755 Returns: 756 The corresponding Dialect instance. 757 """ 758 759 if not dialect: 760 return cls() 761 if isinstance(dialect, _Dialect): 762 return dialect() 763 if isinstance(dialect, Dialect): 764 return dialect 765 if isinstance(dialect, str): 766 try: 767 dialect_name, *kv_strings = dialect.split(",") 768 kv_pairs = (kv.split("=") for kv in kv_strings) 769 kwargs = {} 770 for pair in kv_pairs: 771 key = pair[0].strip() 772 value: t.Union[bool | str | None] = None 773 774 if len(pair) == 1: 775 # Default initialize standalone settings to True 776 value = True 777 elif len(pair) == 2: 778 value = pair[1].strip() 779 780 kwargs[key] = to_bool(value) 781 782 except ValueError: 783 raise ValueError( 784 f"Invalid dialect format: '{dialect}'. " 785 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 786 ) 787 788 result = cls.get(dialect_name.strip()) 789 if not result: 790 from difflib import get_close_matches 791 792 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 793 if similar: 794 similar = f" Did you mean {similar}?" 795 796 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 797 798 return result(**kwargs) 799 800 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
802 @classmethod 803 def format_time( 804 cls, expression: t.Optional[str | exp.Expression] 805 ) -> t.Optional[exp.Expression]: 806 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 807 if isinstance(expression, str): 808 return exp.Literal.string( 809 # the time formats are quoted 810 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 811 ) 812 813 if expression and expression.is_string: 814 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 815 816 return expression
Converts a time format in this dialect to its equivalent Python strftime format.
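For instance, assuming a dialect whose TIME_MAPPING translates Java-style tokens (as Hive's does, mapping yyyy/MM/dd to %Y/%m/%d), a sketch of the conversion; note that plain string inputs are expected to still carry their surrounding quotes, which the slicing above strips:

    from sqlglot.dialects.dialect import Dialect

    fmt = Dialect["hive"].format_time("'yyyy-MM-dd'")
    print(fmt.this)  # expected: %Y-%m-%d, under the mapping assumption above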
836 def normalize_identifier(self, expression: E) -> E: 837 """ 838 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 839 840 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 841 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 842 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 843 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 844 845 There are also dialects like Spark, which are case-insensitive even when quotes are 846 present, and dialects like MySQL, whose resolution rules match those employed by the 847 underlying operating system, for example they may always be case-sensitive in Linux. 848 849 Finally, the normalization behavior of some engines can even be controlled through flags, 850 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 851 852 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 853 that it can analyze queries in the optimizer and successfully capture their semantics. 854 """ 855 if ( 856 isinstance(expression, exp.Identifier) 857 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 858 and ( 859 not expression.quoted 860 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 861 ) 862 ): 863 expression.set( 864 "this", 865 ( 866 expression.this.upper() 867 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 868 else expression.this.lower() 869 ), 870 ) 871 872 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
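A small sketch of the behavior described above, building identifiers with exp.to_identifier (the expected outputs assume the default Postgres LOWERCASE and Snowflake UPPERCASE strategies):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO")  # unquoted
    print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
    print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO

    quoted = exp.to_identifier("FoO", quoted=True)
    # Quoted identifiers are left untouched, since quoting implies case-sensitivity here.
    print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).name)         # FoO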
874 def case_sensitive(self, text: str) -> bool: 875 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 876 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 877 return False 878 879 unsafe = ( 880 str.islower 881 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 882 else str.isupper 883 ) 884 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
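For instance, under the default lowercase-normalizing strategy it is an uppercase character that makes text "unsafe". A sketch:

    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    print(pg.case_sensitive("foo"))  # False: all-lowercase text survives lowercasing unchanged
    print(pg.case_sensitive("Foo"))  # True: the uppercase "F" would be folded away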
886 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 887 """Checks if text can be identified given an identify option. 888 889 Args: 890 text: The text to check. 891 identify: 892 `"always"` or `True`: Always returns `True`. 893 `"safe"`: Only returns `True` if the identifier is case-insensitive. 894 895 Returns: 896 Whether the given text can be identified. 897 """ 898 if identify is True or identify == "always": 899 return True 900 901 if identify == "safe": 902 return not self.case_sensitive(text) 903 904 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
  - "always" or True: Always returns True.
  - "safe": Only returns True if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
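Continuing the sketch above: "safe" identification is simply the negation of case_sensitive, while "always"/True short-circuits to True:

    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    print(pg.can_identify("Foo", "safe"))    # False: "Foo" is case-sensitive under LOWERCASE
    print(pg.can_identify("foo", "safe"))    # True
    print(pg.can_identify("Foo", "always"))  # True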
906 def quote_identifier(self, expression: E, identify: bool = True) -> E: 907 """ 908 Adds quotes to a given identifier. 909 910 Args: 911 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 912 identify: If set to `False`, the quotes will only be added if the identifier is deemed 913 "unsafe", with respect to its characters and this dialect's normalization strategy. 914 """ 915 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 916 name = expression.this 917 expression.set( 918 "quoted", 919 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 920 ) 921 922 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an Identifier, this method is a no-op.
- identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
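A sketch of the quoting behavior; with identify=False only "unsafe" names get quoted (the rendered output assumes Postgres's double-quoted identifiers):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")

    print(pg.quote_identifier(exp.to_identifier("foo"), identify=False).sql(dialect="postgres"))  # foo
    print(pg.quote_identifier(exp.to_identifier("Foo"), identify=False).sql(dialect="postgres"))  # "Foo"
    print(pg.quote_identifier(exp.to_identifier("foo")).sql(dialect="postgres"))                  # "foo" (identify defaults to True)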
924 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 925 if isinstance(path, exp.Literal): 926 path_text = path.name 927 if path.is_number: 928 path_text = f"[{path_text}]" 929 try: 930 return parse_json_path(path_text, self) 931 except ParseError as e: 932 if self.STRICT_JSON_PATH_SYNTAX: 933 logger.warning(f"Invalid JSON path syntax. {str(e)}") 934 935 return path
984def if_sql( 985 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 986) -> t.Callable[[Generator, exp.If], str]: 987 def _if_sql(self: Generator, expression: exp.If) -> str: 988 return self.func( 989 name, 990 expression.this, 991 expression.args.get("true"), 992 expression.args.get("false") or false_value, 993 ) 994 995 return _if_sql
998def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 999 this = expression.this 1000 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 1001 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 1002 1003 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1073def str_position_sql( 1074 self: Generator, 1075 expression: exp.StrPosition, 1076 generate_instance: bool = False, 1077 str_position_func_name: str = "STRPOS", 1078) -> str: 1079 this = self.sql(expression, "this") 1080 substr = self.sql(expression, "substr") 1081 position = self.sql(expression, "position") 1082 instance = expression.args.get("instance") if generate_instance else None 1083 position_offset = "" 1084 1085 if position: 1086 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1087 this = self.func("SUBSTR", this, position) 1088 position_offset = f" + {position} - 1" 1089 1090 return self.func(str_position_func_name, this, substr, instance) + position_offset
1099def var_map_sql( 1100 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1101) -> str: 1102 keys = expression.args["keys"] 1103 values = expression.args["values"] 1104 1105 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1106 self.unsupported("Cannot convert array columns into map.") 1107 return self.func(map_func_name, keys, values) 1108 1109 args = [] 1110 for key, value in zip(keys.expressions, values.expressions): 1111 args.append(self.sql(key)) 1112 args.append(self.sql(value)) 1113 1114 return self.func(map_func_name, *args)
1117def build_formatted_time( 1118 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1119) -> t.Callable[[t.List], E]: 1120 """Helper used for time expressions. 1121 1122 Args: 1123 exp_class: the expression class to instantiate. 1124 dialect: target sql dialect. 1125 default: the default format, True being time. 1126 1127 Returns: 1128 A callable that can be used to return the appropriately formatted time expression. 1129 """ 1130 1131 def _builder(args: t.List): 1132 return exp_class( 1133 this=seq_get(args, 0), 1134 format=Dialect[dialect].format_time( 1135 seq_get(args, 1) 1136 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1137 ), 1138 ) 1139 1140 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
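As a sketch of how such a builder is typically wired up, a dialect parser maps a function name to the returned callable; the STR_TO_TIME/"hive" pairing below is illustrative rather than a claim about any particular dialect's actual mapping:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    builder = build_formatted_time(exp.StrToTime, "hive")

    # A hypothetical parser registration would look like:
    FUNCTIONS = {"STR_TO_TIME": builder}

    # Calling the builder with an already-parsed argument list converts the format
    # via Dialect["hive"].format_time before constructing the expression.
    node = builder([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])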
1143def time_format( 1144 dialect: DialectType = None, 1145) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1146 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1147 """ 1148 Returns the time format for a given expression, unless it's equivalent 1149 to the default time format of the dialect of interest. 1150 """ 1151 time_format = self.format_time(expression) 1152 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1153 1154 return _time_format
1157def build_date_delta( 1158 exp_class: t.Type[E], 1159 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1160 default_unit: t.Optional[str] = "DAY", 1161) -> t.Callable[[t.List], E]: 1162 def _builder(args: t.List) -> E: 1163 unit_based = len(args) == 3 1164 this = args[2] if unit_based else seq_get(args, 0) 1165 unit = None 1166 if unit_based or default_unit: 1167 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1168 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1169 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1170 1171 return _builder
1174def build_date_delta_with_interval( 1175 expression_class: t.Type[E], 1176) -> t.Callable[[t.List], t.Optional[E]]: 1177 def _builder(args: t.List) -> t.Optional[E]: 1178 if len(args) < 2: 1179 return None 1180 1181 interval = args[1] 1182 1183 if not isinstance(interval, exp.Interval): 1184 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1185 1186 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1187 1188 return _builder
1191def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1192 unit = seq_get(args, 0) 1193 this = seq_get(args, 1) 1194 1195 if isinstance(this, exp.Cast) and this.is_type("date"): 1196 return exp.DateTrunc(unit=unit, this=this) 1197 return exp.TimestampTrunc(this=this, unit=unit)
1200def date_add_interval_sql( 1201 data_type: str, kind: str 1202) -> t.Callable[[Generator, exp.Expression], str]: 1203 def func(self: Generator, expression: exp.Expression) -> str: 1204 this = self.sql(expression, "this") 1205 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1206 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1207 1208 return func
1211def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1212 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1213 args = [unit_to_str(expression), expression.this] 1214 if zone: 1215 args.append(expression.args.get("zone")) 1216 return self.func("DATE_TRUNC", *args) 1217 1218 return _timestamptrunc_sql
1221def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1222 zone = expression.args.get("zone") 1223 if not zone: 1224 from sqlglot.optimizer.annotate_types import annotate_types 1225 1226 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1227 return self.sql(exp.cast(expression.this, target_type)) 1228 if zone.name.lower() in TIMEZONES: 1229 return self.sql( 1230 exp.AtTimeZone( 1231 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1232 zone=zone, 1233 ) 1234 ) 1235 return self.func("TIMESTAMP", expression.this, zone)
1238def no_time_sql(self: Generator, expression: exp.Time) -> str: 1239 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1240 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1241 expr = exp.cast( 1242 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1243 ) 1244 return self.sql(expr)
1247def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1248 this = expression.this 1249 expr = expression.expression 1250 1251 if expr.name.lower() in TIMEZONES: 1252 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1253 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1254 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1255 return self.sql(this) 1256 1257 this = exp.cast(this, exp.DataType.Type.DATE) 1258 expr = exp.cast(expr, exp.DataType.Type.TIME) 1259 1260 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1292def timestrtotime_sql( 1293 self: Generator, 1294 expression: exp.TimeStrToTime, 1295 include_precision: bool = False, 1296) -> str: 1297 datatype = exp.DataType.build( 1298 exp.DataType.Type.TIMESTAMPTZ 1299 if expression.args.get("zone") 1300 else exp.DataType.Type.TIMESTAMP 1301 ) 1302 1303 if isinstance(expression.this, exp.Literal) and include_precision: 1304 precision = subsecond_precision(expression.this.name) 1305 if precision > 0: 1306 datatype = exp.DataType.build( 1307 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1308 ) 1309 1310 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
1318def encode_decode_sql( 1319 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1320) -> str: 1321 charset = expression.args.get("charset") 1322 if charset and charset.name.lower() != "utf-8": 1323 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1324 1325 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1338def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1339 cond = expression.this 1340 1341 if isinstance(expression.this, exp.Distinct): 1342 cond = expression.this.expressions[0] 1343 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1344 1345 return self.func("sum", exp.func("if", cond, 1, 0))
1348def trim_sql(self: Generator, expression: exp.Trim) -> str: 1349 target = self.sql(expression, "this") 1350 trim_type = self.sql(expression, "position") 1351 remove_chars = self.sql(expression, "expression") 1352 collation = self.sql(expression, "collation") 1353 1354 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1355 if not remove_chars: 1356 return self.trim_sql(expression) 1357 1358 trim_type = f"{trim_type} " if trim_type else "" 1359 remove_chars = f"{remove_chars} " if remove_chars else "" 1360 from_part = "FROM " if trim_type or remove_chars else "" 1361 collation = f" COLLATE {collation}" if collation else "" 1362 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1383@unsupported_args("position", "occurrence", "parameters") 1384def regexp_extract_sql( 1385 self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll 1386) -> str: 1387 group = expression.args.get("group") 1388 1389 # Do not render group if it's the default value for this dialect 1390 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1391 group = None 1392 1393 return self.func(expression.sql_name(), expression.this, expression.expression, group)
1403def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1404 names = [] 1405 for agg in aggregations: 1406 if isinstance(agg, exp.Alias): 1407 names.append(agg.alias) 1408 else: 1409 """ 1410 This case corresponds to aggregations without aliases being used as suffixes 1411 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1412 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1413 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1414 """ 1415 agg_all_unquoted = agg.transform( 1416 lambda node: ( 1417 exp.Identifier(this=node.name, quoted=False) 1418 if isinstance(node, exp.Identifier) 1419 else node 1420 ) 1421 ) 1422 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1423 1424 return names
1464def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1465 @unsupported_args("count") 1466 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1467 return self.func(name, expression.this, expression.expression) 1468 1469 return _arg_max_or_min_sql
1472def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1473 this = expression.this.copy() 1474 1475 return_type = expression.return_type 1476 if return_type.is_type(exp.DataType.Type.DATE): 1477 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1478 # can truncate timestamp strings, because some dialects can't cast them to DATE 1479 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1480 1481 expression.this.replace(exp.cast(this, return_type)) 1482 return expression
1485def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1486 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1487 if cast and isinstance(expression, exp.TsOrDsAdd): 1488 expression = ts_or_ds_add_cast(expression) 1489 1490 return self.func( 1491 name, 1492 unit_to_var(expression), 1493 expression.expression, 1494 expression.this, 1495 ) 1496 1497 return _delta_sql
1500def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1501 unit = expression.args.get("unit") 1502 1503 if isinstance(unit, exp.Placeholder): 1504 return unit 1505 if unit: 1506 return exp.Literal.string(unit.name) 1507 return exp.Literal.string(default) if default else None
1537def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1538 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1539 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1540 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1541 1542 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1545def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1546 """Remove table refs from columns in when statements.""" 1547 alias = expression.this.args.get("alias") 1548 1549 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1550 return self.dialect.normalize_identifier(identifier).name if identifier else None 1551 1552 targets = {normalize(expression.this.this)} 1553 1554 if alias: 1555 targets.add(normalize(alias.this)) 1556 1557 for when in expression.args["whens"].expressions: 1558 # only remove the target names from the THEN clause 1559 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1560 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1561 then = when.args.get("then") 1562 if then: 1563 then.transform( 1564 lambda node: ( 1565 exp.column(node.this) 1566 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1567 else node 1568 ), 1569 copy=False, 1570 ) 1571 1572 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1575def build_json_extract_path( 1576 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1577) -> t.Callable[[t.List], F]: 1578 def _builder(args: t.List) -> F: 1579 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1580 for arg in args[1:]: 1581 if not isinstance(arg, exp.Literal): 1582 # We use the fallback parser because we can't really transpile non-literals safely 1583 return expr_type.from_arg_list(args) 1584 1585 text = arg.name 1586 if is_int(text): 1587 index = int(text) 1588 segments.append( 1589 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1590 ) 1591 else: 1592 segments.append(exp.JSONPathKey(this=text)) 1593 1594 # This is done to avoid failing in the expression validator due to the arg count 1595 del args[2:] 1596 return expr_type( 1597 this=seq_get(args, 0), 1598 expression=exp.JSONPath(expressions=segments), 1599 only_json_types=arrow_req_json_type, 1600 ) 1601 1602 return _builder
1605def json_extract_segments( 1606 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1607) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1608 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1609 path = expression.expression 1610 if not isinstance(path, exp.JSONPath): 1611 return rename_func(name)(self, expression) 1612 1613 escape = path.args.get("escape") 1614 1615 segments = [] 1616 for segment in path.expressions: 1617 path = self.sql(segment) 1618 if path: 1619 if isinstance(segment, exp.JSONPathPart) and ( 1620 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1621 ): 1622 if escape: 1623 path = self.escape_str(path) 1624 1625 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1626 1627 segments.append(path) 1628 1629 if op: 1630 return f" {op} ".join([self.sql(expression.this), *segments]) 1631 return self.func(name, expression.this, *segments) 1632 1633 return _json_extract_segments
1643def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1644 cond = expression.expression 1645 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1646 alias = cond.expressions[0] 1647 cond = cond.this 1648 elif isinstance(cond, exp.Predicate): 1649 alias = "_u" 1650 else: 1651 self.unsupported("Unsupported filter condition") 1652 return "" 1653 1654 unnest = exp.Unnest(expressions=[expression.this]) 1655 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1656 return self.sql(exp.Array(expressions=[filtered]))
1668def build_default_decimal_type( 1669 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1670) -> t.Callable[[exp.DataType], exp.DataType]: 1671 def _builder(dtype: exp.DataType) -> exp.DataType: 1672 if dtype.expressions or precision is None: 1673 return dtype 1674 1675 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1676 return exp.DataType.build(f"DECIMAL({params})") 1677 1678 return _builder
1681def build_timestamp_from_parts(args: t.List) -> exp.Func: 1682 if len(args) == 2: 1683 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1684 # so we parse this into Anonymous for now instead of introducing complexity 1685 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1686 1687 return exp.TimestampFromParts.from_arg_list(args)
1694def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1695 start = expression.args.get("start") 1696 end = expression.args.get("end") 1697 step = expression.args.get("step") 1698 1699 if isinstance(start, exp.Cast): 1700 target_type = start.to 1701 elif isinstance(end, exp.Cast): 1702 target_type = end.to 1703 else: 1704 target_type = None 1705 1706 if start and end and target_type and target_type.is_type("date", "timestamp"): 1707 if isinstance(start, exp.Cast) and target_type is start.to: 1708 end = exp.cast(end, target_type) 1709 else: 1710 start = exp.cast(start, target_type) 1711 1712 return self.func("SEQUENCE", start, end, step)
1715def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: 1716 def _builder(args: t.List, dialect: Dialect) -> E: 1717 return expr_type( 1718 this=seq_get(args, 0), 1719 expression=seq_get(args, 1), 1720 group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1721 parameters=seq_get(args, 3), 1722 ) 1723 1724 return _builder
1727def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1728 if isinstance(expression.this, exp.Explode): 1729 return self.sql( 1730 exp.Join( 1731 this=exp.Unnest( 1732 expressions=[expression.this.this], 1733 alias=expression.args.get("alias"), 1734 offset=isinstance(expression.this, exp.Posexplode), 1735 ), 1736 kind="cross", 1737 ) 1738 ) 1739 return self.lateral_sql(expression)
1746def no_make_interval_sql(self: Generator, expression: exp.MakeInterval, sep: str = ", ") -> str: 1747 args = [] 1748 for unit, value in expression.args.items(): 1749 if isinstance(value, exp.Kwarg): 1750 value = value.expression 1751 1752 args.append(f"{value} {unit}") 1753 1754 return f"INTERVAL '{self.format_args(*args, sep=sep)}'"