sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator, unsupported_args
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time, subsecond_precision
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    from sqlglot.optimizer.annotate_types import TypeAnnotator

    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        klass.INVERSE_CREATABLE_KIND_MAPPING = {
            v: k for k, v in klass.CREATABLE_KIND_MAPPING.items()
        }

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
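

# Illustrative example (not part of the module): because `_Dialect` is a
# metaclass, merely defining a subclass of the `Dialect` class below registers
# it in `_Dialect.classes` under its lowercased class name, making it
# addressable by string. The dialect name "Custom" is hypothetical:
#
#     >>> class Custom(Dialect):
#     ...     NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_SENSITIVE
#     >>> Dialect.get("custom") is Custom
#     True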


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Whether COPY statement parameters are separated by comma or whitespace."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.

    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    PRESERVE_ORIGINAL_NAMES: bool = False
    """
    Whether the name of the function should be preserved inside the node's metadata.
    This can be useful for roundtripping deprecated vs new functions that share an
    AST node, e.g. JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`).
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before
    column qualification (_qualify_columns()).

    For example:
        WITH data AS (
            SELECT
                1 AS id,
                2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id
        HAVING
            my_id = 1

    In most dialects, "my_id" would refer to "data.my_id" across the query, except:
        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e
          it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - Clickhouse, which will forward the alias across the query i.e it resolves
          to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    SUPPORTS_ORDER_BY_ALL = False
    """
    Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
    """

    HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
    """
    Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3)
    as the former is of type INT[] vs the latter which is SUPER
    """

    SUPPORTS_FIXED_SIZE_ARRAYS = False
    """
    Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g.
    in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should
    be interpreted as a subscript/index operator.
    """

    STRICT_JSON_PATH_SYNTAX = True
    """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning."""

    ON_CONDITION_EMPTY_BEFORE_ERROR = True
    """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle)."""

    ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True
    """Whether ArrayAgg needs to filter NULL values."""

    PROMOTE_TO_INFERRED_DATETIME_TYPE = False
    """
    This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted
    to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal
    is cast to x's type to match it instead.
    """

    REGEXP_EXTRACT_DEFAULT_GROUP = 0
    """The default value for the capturing group."""

    SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = {
        exp.Except: True,
        exp.Intersect: True,
        exp.Union: True,
    }
    """
    Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL`
    must be explicitly specified.
    """

    CREATABLE_KIND_MAPPING: dict[str, str] = {}
    """
    Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse
    equivalent of CREATE SCHEMA is CREATE DATABASE.
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    jsonpath_tokenizer_class = JSONPathTokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}
    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
    INVERSE_FORMAT_TRIE: t.Dict = {}

    INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    DATE_PART_MAPPING = {
        "Y": "YEAR",
        "YY": "YEAR",
        "YYY": "YEAR",
        "YYYY": "YEAR",
        "YR": "YEAR",
        "YEARS": "YEAR",
        "YRS": "YEAR",
        "MM": "MONTH",
        "MON": "MONTH",
        "MONS": "MONTH",
        "MONTHS": "MONTH",
        "D": "DAY",
        "DD": "DAY",
        "DAYS": "DAY",
        "DAYOFMONTH": "DAY",
        "DAY OF WEEK": "DAYOFWEEK",
        "WEEKDAY": "DAYOFWEEK",
        "DOW": "DAYOFWEEK",
        "DW": "DAYOFWEEK",
        "WEEKDAY_ISO": "DAYOFWEEKISO",
        "DOW_ISO": "DAYOFWEEKISO",
        "DW_ISO": "DAYOFWEEKISO",
        "DAY OF YEAR": "DAYOFYEAR",
        "DOY": "DAYOFYEAR",
        "DY": "DAYOFYEAR",
        "W": "WEEK",
        "WK": "WEEK",
        "WEEKOFYEAR": "WEEK",
        "WOY": "WEEK",
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }

    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.Datetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.ToDouble,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Levenshtein,
            exp.Sign,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.JSON: {
            exp.ParseJSON,
        },
        exp.DataType.Type.TIME: {
            exp.Time,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
            exp.Quarter,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.String,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    ANNOTATORS: AnnotatorsType = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Count: lambda self, e: self._annotate_with_type(
            e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT
        ),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Extract: lambda self, e: self._annotate_extract(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<TIMESTAMP>")
        ),
        exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }
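
    # Illustrative example (not part of the module): ANNOTATORS is the table that
    # `sqlglot.optimizer.annotate_types.annotate_types` consults for type inference;
    # e.g. `exp.Concat` resolves to VARCHAR via TYPE_TO_EXPRESSIONS above:
    #
    #     >>> from sqlglot import parse_one
    #     >>> from sqlglot.optimizer.annotate_types import annotate_types
    #     >>> annotate_types(parse_one("SELECT CONCAT('a', 'b') AS c").selects[0]).type.is_type("varchar")
    #     True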

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_strings = dialect.split(",")
                kv_pairs = (kv.split("=") for kv in kv_strings)
                kwargs = {}
                for pair in kv_pairs:
                    key = pair[0].strip()
                    value: t.Union[bool, str, None] = None

                    if len(pair) == 1:
                        # Default initialize standalone settings to True
                        value = True
                    elif len(pair) == 2:
                        value = pair[1].strip()

                        # Coerce the value to boolean if it matches to the truthy/falsy values below
                        value_lower = value.lower()
                        if value_lower in ("true", "1"):
                            value = True
                        elif value_lower in ("false", "0"):
                            value = False

                    kwargs[key] = value

            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
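
    # Illustrative example: in the settings string accepted by `get_or_raise`,
    # bare keys default to True and "true"/"1"/"false"/"0" are coerced to booleans:
    #
    #     >>> d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    #     >>> d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
    #     True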
" 782 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 783 ) 784 785 result = cls.get(dialect_name.strip()) 786 if not result: 787 from difflib import get_close_matches 788 789 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 790 if similar: 791 similar = f" Did you mean {similar}?" 792 793 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 794 795 return result(**kwargs) 796 797 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 798 799 @classmethod 800 def format_time( 801 cls, expression: t.Optional[str | exp.Expression] 802 ) -> t.Optional[exp.Expression]: 803 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 804 if isinstance(expression, str): 805 return exp.Literal.string( 806 # the time formats are quoted 807 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 808 ) 809 810 if expression and expression.is_string: 811 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 812 813 return expression 814 815 def __init__(self, **kwargs) -> None: 816 normalization_strategy = kwargs.pop("normalization_strategy", None) 817 818 if normalization_strategy is None: 819 self.normalization_strategy = self.NORMALIZATION_STRATEGY 820 else: 821 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 822 823 self.settings = kwargs 824 825 def __eq__(self, other: t.Any) -> bool: 826 # Does not currently take dialect state into account 827 return type(self) == other 828 829 def __hash__(self) -> int: 830 # Does not currently take dialect state into account 831 return hash(type(self)) 832 833 def normalize_identifier(self, expression: E) -> E: 834 """ 835 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 836 837 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 838 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 839 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 840 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 841 842 There are also dialects like Spark, which are case-insensitive even when quotes are 843 present, and dialects like MySQL, whose resolution rules match those employed by the 844 underlying operating system, for example they may always be case-sensitive in Linux. 845 846 Finally, the normalization behavior of some engines can even be controlled through flags, 847 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 848 849 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 850 that it can analyze queries in the optimizer and successfully capture their semantics. 
851 """ 852 if ( 853 isinstance(expression, exp.Identifier) 854 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 855 and ( 856 not expression.quoted 857 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 858 ) 859 ): 860 expression.set( 861 "this", 862 ( 863 expression.this.upper() 864 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 865 else expression.this.lower() 866 ), 867 ) 868 869 return expression 870 871 def case_sensitive(self, text: str) -> bool: 872 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 873 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 874 return False 875 876 unsafe = ( 877 str.islower 878 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 879 else str.isupper 880 ) 881 return any(unsafe(char) for char in text) 882 883 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 884 """Checks if text can be identified given an identify option. 885 886 Args: 887 text: The text to check. 888 identify: 889 `"always"` or `True`: Always returns `True`. 890 `"safe"`: Only returns `True` if the identifier is case-insensitive. 891 892 Returns: 893 Whether the given text can be identified. 894 """ 895 if identify is True or identify == "always": 896 return True 897 898 if identify == "safe": 899 return not self.case_sensitive(text) 900 901 return False 902 903 def quote_identifier(self, expression: E, identify: bool = True) -> E: 904 """ 905 Adds quotes to a given identifier. 906 907 Args: 908 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 909 identify: If set to `False`, the quotes will only be added if the identifier is deemed 910 "unsafe", with respect to its characters and this dialect's normalization strategy. 911 """ 912 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 913 name = expression.this 914 expression.set( 915 "quoted", 916 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 917 ) 918 919 return expression 920 921 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 922 if isinstance(path, exp.Literal): 923 path_text = path.name 924 if path.is_number: 925 path_text = f"[{path_text}]" 926 try: 927 return parse_json_path(path_text, self) 928 except ParseError as e: 929 if self.STRICT_JSON_PATH_SYNTAX: 930 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 931 932 return path 933 934 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 935 return self.parser(**opts).parse(self.tokenize(sql), sql) 936 937 def parse_into( 938 self, expression_type: exp.IntoType, sql: str, **opts 939 ) -> t.List[t.Optional[exp.Expression]]: 940 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 941 942 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 943 return self.generator(**opts).generate(expression, copy=copy) 944 945 def transpile(self, sql: str, **opts) -> t.List[str]: 946 return [ 947 self.generate(expression, copy=False, **opts) if expression else "" 948 for expression in self.parse(sql) 949 ] 950 951 def tokenize(self, sql: str) -> t.List[Token]: 952 return self.tokenizer.tokenize(sql) 953 954 @property 955 def tokenizer(self) -> Tokenizer: 956 return self.tokenizer_class(dialect=self) 957 958 @property 959 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 960 return self.jsonpath_tokenizer_class(dialect=self) 961 962 def parser(self, **opts) -> Parser: 963 return self.parser_class(dialect=self, **opts) 964 965 def generator(self, **opts) -> Generator: 966 return self.generator_class(dialect=self, **opts) 967 968 969DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 970 971 972def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 973 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 974 975 976@unsupported_args("accuracy") 977def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 978 return self.func("APPROX_COUNT_DISTINCT", expression.this) 979 980 981def if_sql( 982 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 983) -> t.Callable[[Generator, exp.If], str]: 984 def _if_sql(self: Generator, expression: exp.If) -> str: 985 return self.func( 986 name, 987 expression.this, 988 expression.args.get("true"), 989 expression.args.get("false") or false_value, 990 ) 991 992 return _if_sql 993 994 995def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 996 this = expression.this 997 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 998 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 999 1000 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 1001 1002 1003def inline_array_sql(self: Generator, expression: exp.Array) -> str: 1004 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 1005 1006 1007def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 1008 elem = seq_get(expression.expressions, 0) 1009 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 1010 return self.func("ARRAY", elem) 1011 return inline_array_sql(self, expression) 1012 1013 1014def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 1015 return self.like_sql( 1016 exp.Like( 1017 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 1018 ) 1019 ) 1020 1021 1022def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 1023 zone = self.sql(expression, "this") 1024 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 1025 1026 1027def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 1028 if expression.args.get("recursive"): 1029 self.unsupported("Recursive CTEs are unsupported") 

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"
            try:
                return parse_json_path(path_text, self)
            except ParseError as e:
                if self.STRICT_JSON_PATH_SYNTAX:
                    logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        return self.tokenizer_class(dialect=self)

    @property
    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
        return self.jsonpath_tokenizer_class(dialect=self)

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
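
# Illustrative example: a Dialect instance ties its Tokenizer, Parser and
# Generator together, so round-tripping SQL through a dialect is just
# parse + generate:
#
#     >>> duckdb = Dialect.get_or_raise("duckdb")
#     >>> [e.sql() for e in duckdb.parse("SELECT 1")]
#     ['SELECT 1']
#     >>> duckdb.transpile("SELECT 1")
#     ['SELECT 1']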
1126 """ 1127 1128 def _builder(args: t.List): 1129 return exp_class( 1130 this=seq_get(args, 0), 1131 format=Dialect[dialect].format_time( 1132 seq_get(args, 1) 1133 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1134 ), 1135 ) 1136 1137 return _builder 1138 1139 1140def time_format( 1141 dialect: DialectType = None, 1142) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1143 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1144 """ 1145 Returns the time format for a given expression, unless it's equivalent 1146 to the default time format of the dialect of interest. 1147 """ 1148 time_format = self.format_time(expression) 1149 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1150 1151 return _time_format 1152 1153 1154def build_date_delta( 1155 exp_class: t.Type[E], 1156 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1157 default_unit: t.Optional[str] = "DAY", 1158) -> t.Callable[[t.List], E]: 1159 def _builder(args: t.List) -> E: 1160 unit_based = len(args) == 3 1161 this = args[2] if unit_based else seq_get(args, 0) 1162 unit = None 1163 if unit_based or default_unit: 1164 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1165 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1166 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1167 1168 return _builder 1169 1170 1171def build_date_delta_with_interval( 1172 expression_class: t.Type[E], 1173) -> t.Callable[[t.List], t.Optional[E]]: 1174 def _builder(args: t.List) -> t.Optional[E]: 1175 if len(args) < 2: 1176 return None 1177 1178 interval = args[1] 1179 1180 if not isinstance(interval, exp.Interval): 1181 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1182 1183 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1184 1185 return _builder 1186 1187 1188def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1189 unit = seq_get(args, 0) 1190 this = seq_get(args, 1) 1191 1192 if isinstance(this, exp.Cast) and this.is_type("date"): 1193 return exp.DateTrunc(unit=unit, this=this) 1194 return exp.TimestampTrunc(this=this, unit=unit) 1195 1196 1197def date_add_interval_sql( 1198 data_type: str, kind: str 1199) -> t.Callable[[Generator, exp.Expression], str]: 1200 def func(self: Generator, expression: exp.Expression) -> str: 1201 this = self.sql(expression, "this") 1202 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1203 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1204 1205 return func 1206 1207 1208def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1209 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1210 args = [unit_to_str(expression), expression.this] 1211 if zone: 1212 args.append(expression.args.get("zone")) 1213 return self.func("DATE_TRUNC", *args) 1214 1215 return _timestamptrunc_sql 1216 1217 1218def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1219 zone = expression.args.get("zone") 1220 if not zone: 1221 from sqlglot.optimizer.annotate_types import annotate_types 1222 1223 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1224 return self.sql(exp.cast(expression.this, target_type)) 1225 if zone.name.lower() in TIMEZONES: 1226 return self.sql( 1227 


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E],
    unit_mapping: t.Optional[t.Dict[str, str]] = None,
    default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = None
        if unit_based or default_unit:
            unit = args[0] if unit_based else exp.Literal.string(default_unit)
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    zone = expression.args.get("zone")
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if zone.name.lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=zone,
            )
        )
    return self.func("TIMESTAMP", expression.this, zone)


def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)


def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(
    self: Generator,
    expression: exp.TimeStrToTime,
    include_precision: bool = False,
) -> str:
    datatype = exp.DataType.build(
        exp.DataType.Type.TIMESTAMPTZ
        if expression.args.get("zone")
        else exp.DataType.Type.TIMESTAMP
    )

    if isinstance(expression.this, exp.Literal) and include_precision:
        precision = subsecond_precision(expression.this.name)
        if precision > 0:
            datatype = exp.DataType.build(
                datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))]
            )

    return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


@unsupported_args("position", "occurrence", "parameters")
def regexp_extract_sql(
    self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll
) -> str:
    group = expression.args.get("group")

    # Do not render group if it's the default value for this dialect
    if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP):
        group = None

    return self.func(expression.sql_name(), expression.this, expression.expression, group)


@unsupported_args("position", "occurrence", "modifiers")
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
1411 """ 1412 agg_all_unquoted = agg.transform( 1413 lambda node: ( 1414 exp.Identifier(this=node.name, quoted=False) 1415 if isinstance(node, exp.Identifier) 1416 else node 1417 ) 1418 ) 1419 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1420 1421 return names 1422 1423 1424def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1425 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1426 1427 1428# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1429def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1430 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1431 1432 1433def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1434 return self.func("MAX", expression.this) 1435 1436 1437def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1438 a = self.sql(expression.left) 1439 b = self.sql(expression.right) 1440 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1441 1442 1443def is_parse_json(expression: exp.Expression) -> bool: 1444 return isinstance(expression, exp.ParseJSON) or ( 1445 isinstance(expression, exp.Cast) and expression.is_type("json") 1446 ) 1447 1448 1449def isnull_to_is_null(args: t.List) -> exp.Expression: 1450 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1451 1452 1453def generatedasidentitycolumnconstraint_sql( 1454 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1455) -> str: 1456 start = self.sql(expression, "start") or "1" 1457 increment = self.sql(expression, "increment") or "1" 1458 return f"IDENTITY({start}, {increment})" 1459 1460 1461def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1462 @unsupported_args("count") 1463 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1464 return self.func(name, expression.this, expression.expression) 1465 1466 return _arg_max_or_min_sql 1467 1468 1469def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1470 this = expression.this.copy() 1471 1472 return_type = expression.return_type 1473 if return_type.is_type(exp.DataType.Type.DATE): 1474 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1475 # can truncate timestamp strings, because some dialects can't cast them to DATE 1476 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1477 1478 expression.this.replace(exp.cast(this, return_type)) 1479 return expression 1480 1481 1482def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1483 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1484 if cast and isinstance(expression, exp.TsOrDsAdd): 1485 expression = ts_or_ds_add_cast(expression) 1486 1487 return self.func( 1488 name, 1489 unit_to_var(expression), 1490 expression.expression, 1491 expression.this, 1492 ) 1493 1494 return _delta_sql 1495 1496 1497def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1498 unit = expression.args.get("unit") 1499 1500 if isinstance(unit, exp.Placeholder): 1501 return unit 1502 if unit: 1503 return exp.Literal.string(unit.name) 1504 return exp.Literal.string(default) if default else None 1505 1506 1507def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1508 unit = expression.args.get("unit") 1509 1510 if isinstance(unit, (exp.Var, exp.Placeholder)): 1511 


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        # Only remove the target names from the THEN clause; they're still valid in the
        # <condition> part of WHEN MATCHED / WHEN NOT MATCHED.
        # ref: https://github.com/TobikoData/sqlmesh/issues/2934
        then = when.args.get("then")
        if then:
            then.transform(
                lambda node: (
                    exp.column(node.this)
                    if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                    else node
                ),
                copy=False,
            )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
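
# Illustrative example: `build_json_extract_path` folds variadic path arguments
# into a single JSONPath expression (output shown for the default dialect):
#
#     >>> builder = build_json_extract_path(exp.JSONExtract)
#     >>> builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)]).expression.sql()
#     '$.a[0]'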


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        escape = path.args.get("escape")

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    if escape:
                        path = self.escape_str(path)

                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder


def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)


def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)


def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str:
    start = expression.args.get("start")
    end = expression.args.get("end")
    step = expression.args.get("step")

    if isinstance(start, exp.Cast):
        target_type = start.to
    elif isinstance(end, exp.Cast):
        target_type = end.to
    else:
        target_type = None

    if start and end and target_type and target_type.is_type("date", "timestamp"):
        if isinstance(start, exp.Cast) and target_type is start.to:
            end = exp.cast(end, target_type)
        else:
            start = exp.cast(start, target_type)

    return self.func("SEQUENCE", start, end, step)
exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1718 parameters=seq_get(args, 3), 1719 ) 1720 1721 return _builder 1722 1723 1724def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1725 if isinstance(expression.this, exp.Explode): 1726 return self.sql( 1727 exp.Join( 1728 this=exp.Unnest( 1729 expressions=[expression.this.this], 1730 alias=expression.args.get("alias"), 1731 offset=isinstance(expression.this, exp.Posexplode), 1732 ), 1733 kind="cross", 1734 ) 1735 ) 1736 return self.lateral_sql(expression) 1737 1738 1739def timestampdiff_sql(self: Generator, expression: exp.DatetimeDiff | exp.TimestampDiff) -> str: 1740 return self.func("TIMESTAMPDIFF", expression.unit, expression.expression, expression.this) 1741 1742 1743def no_make_interval_sql(self: Generator, expression: exp.MakeInterval, sep: str = ", ") -> str: 1744 args = [] 1745 for unit, value in expression.args.items(): 1746 if isinstance(value, exp.Kwarg): 1747 value = value.expression 1748 1749 args.append(f"{value} {unit}") 1750 1751 return f"INTERVAL '{self.format_args(*args, sep=sep)}'"
49class Dialects(str, Enum): 50 """Dialects supported by SQLGLot.""" 51 52 DIALECT = "" 53 54 ATHENA = "athena" 55 BIGQUERY = "bigquery" 56 CLICKHOUSE = "clickhouse" 57 DATABRICKS = "databricks" 58 DORIS = "doris" 59 DRILL = "drill" 60 DUCKDB = "duckdb" 61 HIVE = "hive" 62 MATERIALIZE = "materialize" 63 MYSQL = "mysql" 64 ORACLE = "oracle" 65 POSTGRES = "postgres" 66 PRESTO = "presto" 67 PRQL = "prql" 68 REDSHIFT = "redshift" 69 RISINGWAVE = "risingwave" 70 SNOWFLAKE = "snowflake" 71 SPARK = "spark" 72 SPARK2 = "spark2" 73 SQLITE = "sqlite" 74 STARROCKS = "starrocks" 75 TABLEAU = "tableau" 76 TERADATA = "teradata" 77 TRINO = "trino" 78 TSQL = "tsql"
Dialects supported by SQLGlot.
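These enum values double as the dialect names accepted by the public API; a minimal sketch using `sqlglot.transpile` (output shown is what the conversion is expected to produce):

    import sqlglot

    # Enum values such as "mysql" and "duckdb" are valid `read`/`write` arguments:
    print(sqlglot.transpile("SELECT `a` FROM `t`", read="mysql", write="duckdb")[0])
    # SELECT "a" FROM "t"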
81class NormalizationStrategy(str, AutoName): 82 """Specifies the strategy according to which identifiers should be normalized.""" 83 84 LOWERCASE = auto() 85 """Unquoted identifiers are lowercased.""" 86 87 UPPERCASE = auto() 88 """Unquoted identifiers are uppercased.""" 89 90 CASE_SENSITIVE = auto() 91 """Always case-sensitive, regardless of quotes.""" 92 93 CASE_INSENSITIVE = auto() 94 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.

- LOWERCASE: Unquoted identifiers are lowercased.
- UPPERCASE: Unquoted identifiers are uppercased.
- CASE_SENSITIVE: Always case-sensitive, regardless of quotes.
- CASE_INSENSITIVE: Always case-insensitive, regardless of quotes.
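The strategy is usually taken from the dialect's `NORMALIZATION_STRATEGY` default, but it can be overridden through the settings string handled by `Dialect.get_or_raise` (documented further down); a minimal sketch:

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    # Snowflake defaults to UPPERCASE; override it through the settings string:
    dialect = Dialect.get_or_raise("snowflake, normalization_strategy = case_sensitive")
    assert dialect.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE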
219class Dialect(metaclass=_Dialect): 220 INDEX_OFFSET = 0 221 """The base index offset for arrays.""" 222 223 WEEK_OFFSET = 0 224 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 225 226 UNNEST_COLUMN_ONLY = False 227 """Whether `UNNEST` table aliases are treated as column aliases.""" 228 229 ALIAS_POST_TABLESAMPLE = False 230 """Whether the table alias comes after tablesample.""" 231 232 TABLESAMPLE_SIZE_IS_PERCENT = False 233 """Whether a size in the table sample clause represents percentage.""" 234 235 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 236 """Specifies the strategy according to which identifiers should be normalized.""" 237 238 IDENTIFIERS_CAN_START_WITH_DIGIT = False 239 """Whether an unquoted identifier can start with a digit.""" 240 241 DPIPE_IS_STRING_CONCAT = True 242 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 243 244 STRICT_STRING_CONCAT = False 245 """Whether `CONCAT`'s arguments must be strings.""" 246 247 SUPPORTS_USER_DEFINED_TYPES = True 248 """Whether user-defined data types are supported.""" 249 250 SUPPORTS_SEMI_ANTI_JOIN = True 251 """Whether `SEMI` or `ANTI` joins are supported.""" 252 253 SUPPORTS_COLUMN_JOIN_MARKS = False 254 """Whether the old-style outer join (+) syntax is supported.""" 255 256 COPY_PARAMS_ARE_CSV = True 257 """Separator of COPY statement parameters.""" 258 259 NORMALIZE_FUNCTIONS: bool | str = "upper" 260 """ 261 Determines how function names are going to be normalized. 262 Possible values: 263 "upper" or True: Convert names to uppercase. 264 "lower": Convert names to lowercase. 265 False: Disables function name normalization. 266 """ 267 268 PRESERVE_ORIGINAL_NAMES: bool = False 269 """ 270 Whether the name of the function should be preserved inside the node's metadata, 271 can be useful for roundtripping deprecated vs new functions that share an AST node 272 e.g JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery 273 """ 274 275 LOG_BASE_FIRST: t.Optional[bool] = True 276 """ 277 Whether the base comes first in the `LOG` function. 278 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 279 """ 280 281 NULL_ORDERING = "nulls_are_small" 282 """ 283 Default `NULL` ordering method to use if not explicitly set. 284 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 285 """ 286 287 TYPED_DIVISION = False 288 """ 289 Whether the behavior of `a / b` depends on the types of `a` and `b`. 290 False means `a / b` is always float division. 291 True means `a / b` is integer division if both `a` and `b` are integers. 
292 """ 293 294 SAFE_DIVISION = False 295 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 296 297 CONCAT_COALESCE = False 298 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 299 300 HEX_LOWERCASE = False 301 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 302 303 DATE_FORMAT = "'%Y-%m-%d'" 304 DATEINT_FORMAT = "'%Y%m%d'" 305 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 306 307 TIME_MAPPING: t.Dict[str, str] = {} 308 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 309 310 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 311 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 312 FORMAT_MAPPING: t.Dict[str, str] = {} 313 """ 314 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 315 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 316 """ 317 318 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 319 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 320 321 PSEUDOCOLUMNS: t.Set[str] = set() 322 """ 323 Columns that are auto-generated by the engine corresponding to this dialect. 324 For example, such columns may be excluded from `SELECT *` queries. 325 """ 326 327 PREFER_CTE_ALIAS_COLUMN = False 328 """ 329 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 330 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 331 any projection aliases in the subquery. 332 333 For example, 334 WITH y(c) AS ( 335 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 336 ) SELECT c FROM y; 337 338 will be rewritten as 339 340 WITH y(c) AS ( 341 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 342 ) SELECT c FROM y; 343 """ 344 345 COPY_PARAMS_ARE_CSV = True 346 """ 347 Whether COPY statement parameters are separated by comma or whitespace 348 """ 349 350 FORCE_EARLY_ALIAS_REF_EXPANSION = False 351 """ 352 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 
353 354 For example: 355 WITH data AS ( 356 SELECT 357 1 AS id, 358 2 AS my_id 359 ) 360 SELECT 361 id AS my_id 362 FROM 363 data 364 WHERE 365 my_id = 1 366 GROUP BY 367 my_id, 368 HAVING 369 my_id = 1 370 371 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 372 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 373 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 374 - Clickhouse, which will forward the alias across the query i.e it resolves 375 to "WHERE id = 1 GROUP BY id HAVING id = 1" 376 """ 377 378 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 379 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 380 381 SUPPORTS_ORDER_BY_ALL = False 382 """ 383 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 384 """ 385 386 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 387 """ 388 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 389 as the former is of type INT[] vs the latter which is SUPER 390 """ 391 392 SUPPORTS_FIXED_SIZE_ARRAYS = False 393 """ 394 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 395 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 396 be interpreted as a subscript/index operator. 397 """ 398 399 STRICT_JSON_PATH_SYNTAX = True 400 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 401 402 ON_CONDITION_EMPTY_BEFORE_ERROR = True 403 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 404 405 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 406 """Whether ArrayAgg needs to filter NULL values.""" 407 408 PROMOTE_TO_INFERRED_DATETIME_TYPE = False 409 """ 410 This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted 411 to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal 412 is cast to x's type to match it instead. 413 """ 414 415 REGEXP_EXTRACT_DEFAULT_GROUP = 0 416 """The default value for the capturing group.""" 417 418 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 419 exp.Except: True, 420 exp.Intersect: True, 421 exp.Union: True, 422 } 423 """ 424 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 425 must be explicitly specified. 426 """ 427 428 CREATABLE_KIND_MAPPING: dict[str, str] = {} 429 """ 430 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 431 equivalent of CREATE SCHEMA is CREATE DATABASE. 
432 """ 433 434 # --- Autofilled --- 435 436 tokenizer_class = Tokenizer 437 jsonpath_tokenizer_class = JSONPathTokenizer 438 parser_class = Parser 439 generator_class = Generator 440 441 # A trie of the time_mapping keys 442 TIME_TRIE: t.Dict = {} 443 FORMAT_TRIE: t.Dict = {} 444 445 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 446 INVERSE_TIME_TRIE: t.Dict = {} 447 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 448 INVERSE_FORMAT_TRIE: t.Dict = {} 449 450 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 451 452 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 453 454 # Delimiters for string literals and identifiers 455 QUOTE_START = "'" 456 QUOTE_END = "'" 457 IDENTIFIER_START = '"' 458 IDENTIFIER_END = '"' 459 460 # Delimiters for bit, hex, byte and unicode literals 461 BIT_START: t.Optional[str] = None 462 BIT_END: t.Optional[str] = None 463 HEX_START: t.Optional[str] = None 464 HEX_END: t.Optional[str] = None 465 BYTE_START: t.Optional[str] = None 466 BYTE_END: t.Optional[str] = None 467 UNICODE_START: t.Optional[str] = None 468 UNICODE_END: t.Optional[str] = None 469 470 DATE_PART_MAPPING = { 471 "Y": "YEAR", 472 "YY": "YEAR", 473 "YYY": "YEAR", 474 "YYYY": "YEAR", 475 "YR": "YEAR", 476 "YEARS": "YEAR", 477 "YRS": "YEAR", 478 "MM": "MONTH", 479 "MON": "MONTH", 480 "MONS": "MONTH", 481 "MONTHS": "MONTH", 482 "D": "DAY", 483 "DD": "DAY", 484 "DAYS": "DAY", 485 "DAYOFMONTH": "DAY", 486 "DAY OF WEEK": "DAYOFWEEK", 487 "WEEKDAY": "DAYOFWEEK", 488 "DOW": "DAYOFWEEK", 489 "DW": "DAYOFWEEK", 490 "WEEKDAY_ISO": "DAYOFWEEKISO", 491 "DOW_ISO": "DAYOFWEEKISO", 492 "DW_ISO": "DAYOFWEEKISO", 493 "DAY OF YEAR": "DAYOFYEAR", 494 "DOY": "DAYOFYEAR", 495 "DY": "DAYOFYEAR", 496 "W": "WEEK", 497 "WK": "WEEK", 498 "WEEKOFYEAR": "WEEK", 499 "WOY": "WEEK", 500 "WY": "WEEK", 501 "WEEK_ISO": "WEEKISO", 502 "WEEKOFYEARISO": "WEEKISO", 503 "WEEKOFYEAR_ISO": "WEEKISO", 504 "Q": "QUARTER", 505 "QTR": "QUARTER", 506 "QTRS": "QUARTER", 507 "QUARTERS": "QUARTER", 508 "H": "HOUR", 509 "HH": "HOUR", 510 "HR": "HOUR", 511 "HOURS": "HOUR", 512 "HRS": "HOUR", 513 "M": "MINUTE", 514 "MI": "MINUTE", 515 "MIN": "MINUTE", 516 "MINUTES": "MINUTE", 517 "MINS": "MINUTE", 518 "S": "SECOND", 519 "SEC": "SECOND", 520 "SECONDS": "SECOND", 521 "SECS": "SECOND", 522 "MS": "MILLISECOND", 523 "MSEC": "MILLISECOND", 524 "MSECS": "MILLISECOND", 525 "MSECOND": "MILLISECOND", 526 "MSECONDS": "MILLISECOND", 527 "MILLISEC": "MILLISECOND", 528 "MILLISECS": "MILLISECOND", 529 "MILLISECON": "MILLISECOND", 530 "MILLISECONDS": "MILLISECOND", 531 "US": "MICROSECOND", 532 "USEC": "MICROSECOND", 533 "USECS": "MICROSECOND", 534 "MICROSEC": "MICROSECOND", 535 "MICROSECS": "MICROSECOND", 536 "USECOND": "MICROSECOND", 537 "USECONDS": "MICROSECOND", 538 "MICROSECONDS": "MICROSECOND", 539 "NS": "NANOSECOND", 540 "NSEC": "NANOSECOND", 541 "NANOSEC": "NANOSECOND", 542 "NSECOND": "NANOSECOND", 543 "NSECONDS": "NANOSECOND", 544 "NANOSECS": "NANOSECOND", 545 "EPOCH_SECOND": "EPOCH", 546 "EPOCH_SECONDS": "EPOCH", 547 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 548 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 549 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 550 "TZH": "TIMEZONE_HOUR", 551 "TZM": "TIMEZONE_MINUTE", 552 "DEC": "DECADE", 553 "DECS": "DECADE", 554 "DECADES": "DECADE", 555 "MIL": "MILLENIUM", 556 "MILS": "MILLENIUM", 557 "MILLENIA": "MILLENIUM", 558 "C": "CENTURY", 559 "CENT": "CENTURY", 560 "CENTS": "CENTURY", 561 "CENTURIES": "CENTURY", 562 } 563 564 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 565 exp.DataType.Type.BIGINT: 
{ 566 exp.ApproxDistinct, 567 exp.ArraySize, 568 exp.Length, 569 }, 570 exp.DataType.Type.BOOLEAN: { 571 exp.Between, 572 exp.Boolean, 573 exp.In, 574 exp.RegexpLike, 575 }, 576 exp.DataType.Type.DATE: { 577 exp.CurrentDate, 578 exp.Date, 579 exp.DateFromParts, 580 exp.DateStrToDate, 581 exp.DiToDate, 582 exp.StrToDate, 583 exp.TimeStrToDate, 584 exp.TsOrDsToDate, 585 }, 586 exp.DataType.Type.DATETIME: { 587 exp.CurrentDatetime, 588 exp.Datetime, 589 exp.DatetimeAdd, 590 exp.DatetimeSub, 591 }, 592 exp.DataType.Type.DOUBLE: { 593 exp.ApproxQuantile, 594 exp.Avg, 595 exp.Exp, 596 exp.Ln, 597 exp.Log, 598 exp.Pow, 599 exp.Quantile, 600 exp.Round, 601 exp.SafeDivide, 602 exp.Sqrt, 603 exp.Stddev, 604 exp.StddevPop, 605 exp.StddevSamp, 606 exp.ToDouble, 607 exp.Variance, 608 exp.VariancePop, 609 }, 610 exp.DataType.Type.INT: { 611 exp.Ceil, 612 exp.DatetimeDiff, 613 exp.DateDiff, 614 exp.TimestampDiff, 615 exp.TimeDiff, 616 exp.DateToDi, 617 exp.Levenshtein, 618 exp.Sign, 619 exp.StrPosition, 620 exp.TsOrDiToDi, 621 }, 622 exp.DataType.Type.JSON: { 623 exp.ParseJSON, 624 }, 625 exp.DataType.Type.TIME: { 626 exp.Time, 627 }, 628 exp.DataType.Type.TIMESTAMP: { 629 exp.CurrentTime, 630 exp.CurrentTimestamp, 631 exp.StrToTime, 632 exp.TimeAdd, 633 exp.TimeStrToTime, 634 exp.TimeSub, 635 exp.TimestampAdd, 636 exp.TimestampSub, 637 exp.UnixToTime, 638 }, 639 exp.DataType.Type.TINYINT: { 640 exp.Day, 641 exp.Month, 642 exp.Week, 643 exp.Year, 644 exp.Quarter, 645 }, 646 exp.DataType.Type.VARCHAR: { 647 exp.ArrayConcat, 648 exp.Concat, 649 exp.ConcatWs, 650 exp.DateToDateStr, 651 exp.GroupConcat, 652 exp.Initcap, 653 exp.Lower, 654 exp.Substring, 655 exp.String, 656 exp.TimeToStr, 657 exp.TimeToTimeStr, 658 exp.Trim, 659 exp.TsOrDsToDateStr, 660 exp.UnixToStr, 661 exp.UnixToTimeStr, 662 exp.Upper, 663 }, 664 } 665 666 ANNOTATORS: AnnotatorsType = { 667 **{ 668 expr_type: lambda self, e: self._annotate_unary(e) 669 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 670 }, 671 **{ 672 expr_type: lambda self, e: self._annotate_binary(e) 673 for expr_type in subclasses(exp.__name__, exp.Binary) 674 }, 675 **{ 676 expr_type: _annotate_with_type_lambda(data_type) 677 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 678 for expr_type in expressions 679 }, 680 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 681 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 682 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 683 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 684 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 685 exp.Bracket: lambda self, e: self._annotate_bracket(e), 686 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 687 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 688 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 689 exp.Count: lambda self, e: self._annotate_with_type( 690 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 691 ), 692 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 693 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 694 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 695 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 696 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 697 exp.Div: lambda self, e: self._annotate_div(e), 698 exp.Dot: lambda 
self, e: self._annotate_dot(e), 699 exp.Explode: lambda self, e: self._annotate_explode(e), 700 exp.Extract: lambda self, e: self._annotate_extract(e), 701 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 702 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 703 e, exp.DataType.build("ARRAY<DATE>") 704 ), 705 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 706 e, exp.DataType.build("ARRAY<TIMESTAMP>") 707 ), 708 exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 709 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 710 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 711 exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 712 exp.Literal: lambda self, e: self._annotate_literal(e), 713 exp.Map: lambda self, e: self._annotate_map(e), 714 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 715 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 716 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 717 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 718 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 719 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 720 exp.Struct: lambda self, e: self._annotate_struct(e), 721 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 722 exp.Timestamp: lambda self, e: self._annotate_with_type( 723 e, 724 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 725 ), 726 exp.ToMap: lambda self, e: self._annotate_to_map(e), 727 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 728 exp.Unnest: lambda self, e: self._annotate_unnest(e), 729 exp.VarMap: lambda self, e: self._annotate_map(e), 730 } 731 732 @classmethod 733 def get_or_raise(cls, dialect: DialectType) -> Dialect: 734 """ 735 Look up a dialect in the global dialect registry and return it if it exists. 736 737 Args: 738 dialect: The target dialect. If this is a string, it can be optionally followed by 739 additional key-value pairs that are separated by commas and are used to specify 740 dialect settings, such as whether the dialect's identifiers are case-sensitive. 741 742 Example: 743 >>> dialect = dialect_class = get_or_raise("duckdb") 744 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 745 746 Returns: 747 The corresponding Dialect instance. 748 """ 749 750 if not dialect: 751 return cls() 752 if isinstance(dialect, _Dialect): 753 return dialect() 754 if isinstance(dialect, Dialect): 755 return dialect 756 if isinstance(dialect, str): 757 try: 758 dialect_name, *kv_strings = dialect.split(",") 759 kv_pairs = (kv.split("=") for kv in kv_strings) 760 kwargs = {} 761 for pair in kv_pairs: 762 key = pair[0].strip() 763 value: t.Union[bool | str | None] = None 764 765 if len(pair) == 1: 766 # Default initialize standalone settings to True 767 value = True 768 elif len(pair) == 2: 769 value = pair[1].strip() 770 771 # Coerce the value to boolean if it matches to the truthy/falsy values below 772 value_lower = value.lower() 773 if value_lower in ("true", "1"): 774 value = True 775 elif value_lower in ("false", "0"): 776 value = False 777 778 kwargs[key] = value 779 780 except ValueError: 781 raise ValueError( 782 f"Invalid dialect format: '{dialect}'. 
" 783 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 784 ) 785 786 result = cls.get(dialect_name.strip()) 787 if not result: 788 from difflib import get_close_matches 789 790 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 791 if similar: 792 similar = f" Did you mean {similar}?" 793 794 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 795 796 return result(**kwargs) 797 798 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 799 800 @classmethod 801 def format_time( 802 cls, expression: t.Optional[str | exp.Expression] 803 ) -> t.Optional[exp.Expression]: 804 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 805 if isinstance(expression, str): 806 return exp.Literal.string( 807 # the time formats are quoted 808 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 809 ) 810 811 if expression and expression.is_string: 812 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 813 814 return expression 815 816 def __init__(self, **kwargs) -> None: 817 normalization_strategy = kwargs.pop("normalization_strategy", None) 818 819 if normalization_strategy is None: 820 self.normalization_strategy = self.NORMALIZATION_STRATEGY 821 else: 822 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 823 824 self.settings = kwargs 825 826 def __eq__(self, other: t.Any) -> bool: 827 # Does not currently take dialect state into account 828 return type(self) == other 829 830 def __hash__(self) -> int: 831 # Does not currently take dialect state into account 832 return hash(type(self)) 833 834 def normalize_identifier(self, expression: E) -> E: 835 """ 836 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 837 838 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 839 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 840 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 841 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 842 843 There are also dialects like Spark, which are case-insensitive even when quotes are 844 present, and dialects like MySQL, whose resolution rules match those employed by the 845 underlying operating system, for example they may always be case-sensitive in Linux. 846 847 Finally, the normalization behavior of some engines can even be controlled through flags, 848 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 849 850 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 851 that it can analyze queries in the optimizer and successfully capture their semantics. 
852 """ 853 if ( 854 isinstance(expression, exp.Identifier) 855 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 856 and ( 857 not expression.quoted 858 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 859 ) 860 ): 861 expression.set( 862 "this", 863 ( 864 expression.this.upper() 865 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 866 else expression.this.lower() 867 ), 868 ) 869 870 return expression 871 872 def case_sensitive(self, text: str) -> bool: 873 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 874 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 875 return False 876 877 unsafe = ( 878 str.islower 879 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 880 else str.isupper 881 ) 882 return any(unsafe(char) for char in text) 883 884 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 885 """Checks if text can be identified given an identify option. 886 887 Args: 888 text: The text to check. 889 identify: 890 `"always"` or `True`: Always returns `True`. 891 `"safe"`: Only returns `True` if the identifier is case-insensitive. 892 893 Returns: 894 Whether the given text can be identified. 895 """ 896 if identify is True or identify == "always": 897 return True 898 899 if identify == "safe": 900 return not self.case_sensitive(text) 901 902 return False 903 904 def quote_identifier(self, expression: E, identify: bool = True) -> E: 905 """ 906 Adds quotes to a given identifier. 907 908 Args: 909 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 910 identify: If set to `False`, the quotes will only be added if the identifier is deemed 911 "unsafe", with respect to its characters and this dialect's normalization strategy. 912 """ 913 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 914 name = expression.this 915 expression.set( 916 "quoted", 917 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 918 ) 919 920 return expression 921 922 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 923 if isinstance(path, exp.Literal): 924 path_text = path.name 925 if path.is_number: 926 path_text = f"[{path_text}]" 927 try: 928 return parse_json_path(path_text, self) 929 except ParseError as e: 930 if self.STRICT_JSON_PATH_SYNTAX: 931 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 932 933 return path 934 935 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 936 return self.parser(**opts).parse(self.tokenize(sql), sql) 937 938 def parse_into( 939 self, expression_type: exp.IntoType, sql: str, **opts 940 ) -> t.List[t.Optional[exp.Expression]]: 941 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 942 943 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 944 return self.generator(**opts).generate(expression, copy=copy) 945 946 def transpile(self, sql: str, **opts) -> t.List[str]: 947 return [ 948 self.generate(expression, copy=False, **opts) if expression else "" 949 for expression in self.parse(sql) 950 ] 951 952 def tokenize(self, sql: str) -> t.List[Token]: 953 return self.tokenizer.tokenize(sql) 954 955 @property 956 def tokenizer(self) -> Tokenizer: 957 return self.tokenizer_class(dialect=self) 958 959 @property 960 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 961 return self.jsonpath_tokenizer_class(dialect=self) 962 963 def parser(self, **opts) -> Parser: 964 return self.parser_class(dialect=self, **opts) 965 966 def generator(self, **opts) -> Generator: 967 return self.generator_class(dialect=self, **opts)
816 def __init__(self, **kwargs) -> None: 817 normalization_strategy = kwargs.pop("normalization_strategy", None) 818 819 if normalization_strategy is None: 820 self.normalization_strategy = self.NORMALIZATION_STRATEGY 821 else: 822 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 823 824 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the name of the function should be preserved inside the node's metadata; this can be useful for roundtripping deprecated vs. new functions that share an AST node, e.g. JSON_VALUE vs. JSON_EXTRACT_SCALAR in BigQuery.
Whether the base comes first in the `LOG` function.
Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
Default `NULL` ordering method to use if not explicitly set.
Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
Whether the behavior of `a / b` depends on the types of `a` and `b`.
False means `a / b` is always float division.
True means `a / b` is integer division if both `a` and `b` are integers.
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python `strftime` formats.
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from `SELECT *` queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
    WITH data AS (
        SELECT
            1 AS id,
            2 AS my_id
    )
    SELECT
        id AS my_id
    FROM
        data
    WHERE
        my_id = 1
    GROUP BY
        my_id
    HAVING
        my_id = 1
In most dialects, "my_id" would refer to "data.my_id" across the query, except:

- BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it resolves
  to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- Clickhouse, which will forward the alias across the query, i.e. it resolves
  to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (it expands to all the selected columns), as in DuckDB and Spark3/Databricks.
Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3), as the former is of type INT[] while the latter is of type SUPER.
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts, e.g. in DuckDB. In dialects which don't support fixed-size arrays, such as Snowflake, this should be interpreted as a subscript/index operator.
Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.
Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).
This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal is cast to x's type to match it instead.
Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` must be explicitly specified.
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
732 @classmethod 733 def get_or_raise(cls, dialect: DialectType) -> Dialect: 734 """ 735 Look up a dialect in the global dialect registry and return it if it exists. 736 737 Args: 738 dialect: The target dialect. If this is a string, it can be optionally followed by 739 additional key-value pairs that are separated by commas and are used to specify 740 dialect settings, such as whether the dialect's identifiers are case-sensitive. 741 742 Example: 743 >>> dialect = dialect_class = get_or_raise("duckdb") 744 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 745 746 Returns: 747 The corresponding Dialect instance. 748 """ 749 750 if not dialect: 751 return cls() 752 if isinstance(dialect, _Dialect): 753 return dialect() 754 if isinstance(dialect, Dialect): 755 return dialect 756 if isinstance(dialect, str): 757 try: 758 dialect_name, *kv_strings = dialect.split(",") 759 kv_pairs = (kv.split("=") for kv in kv_strings) 760 kwargs = {} 761 for pair in kv_pairs: 762 key = pair[0].strip() 763 value: t.Union[bool | str | None] = None 764 765 if len(pair) == 1: 766 # Default initialize standalone settings to True 767 value = True 768 elif len(pair) == 2: 769 value = pair[1].strip() 770 771 # Coerce the value to boolean if it matches to the truthy/falsy values below 772 value_lower = value.lower() 773 if value_lower in ("true", "1"): 774 value = True 775 elif value_lower in ("false", "0"): 776 value = False 777 778 kwargs[key] = value 779 780 except ValueError: 781 raise ValueError( 782 f"Invalid dialect format: '{dialect}'. " 783 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 784 ) 785 786 result = cls.get(dialect_name.strip()) 787 if not result: 788 from difflib import get_close_matches 789 790 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 791 if similar: 792 similar = f" Did you mean {similar}?" 793 794 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 795 796 return result(**kwargs) 797 798 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
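Key-value pairs that aren't consumed by `__init__` end up in the instance's `settings` dict (see `__init__` below). A sketch; `my_flag` and `strict` are hypothetical setting names used only for illustration:

    from sqlglot.dialects.dialect import Dialect

    # Standalone settings default to True; "true"/"false"/"1"/"0" are coerced to booleans.
    dialect = Dialect.get_or_raise("duckdb, my_flag, strict = false")
    print(dialect.settings)  # {'my_flag': True, 'strict': False}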
800 @classmethod 801 def format_time( 802 cls, expression: t.Optional[str | exp.Expression] 803 ) -> t.Optional[exp.Expression]: 804 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 805 if isinstance(expression, str): 806 return exp.Literal.string( 807 # the time formats are quoted 808 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 809 ) 810 811 if expression and expression.is_string: 812 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 813 814 return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
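A sketch using MySQL, whose `TIME_MAPPING` translates its `%i` (minutes) to Python's `%M`; note that the string form is assumed to be quoted, hence the `expression[1:-1]` in the source above:

    from sqlglot.dialects.mysql import MySQL

    print(MySQL.format_time("'%Y-%m-%d %i'").sql())  # '%Y-%m-%d %M'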
834 def normalize_identifier(self, expression: E) -> E: 835 """ 836 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 837 838 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 839 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 840 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 841 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 842 843 There are also dialects like Spark, which are case-insensitive even when quotes are 844 present, and dialects like MySQL, whose resolution rules match those employed by the 845 underlying operating system, for example they may always be case-sensitive in Linux. 846 847 Finally, the normalization behavior of some engines can even be controlled through flags, 848 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 849 850 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 851 that it can analyze queries in the optimizer and successfully capture their semantics. 852 """ 853 if ( 854 isinstance(expression, exp.Identifier) 855 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 856 and ( 857 not expression.quoted 858 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 859 ) 860 ): 861 expression.set( 862 "this", 863 ( 864 expression.this.upper() 865 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 866 else expression.this.lower() 867 ), 868 ) 869 870 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it
would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and
so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
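A minimal sketch of the behaviors described above:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO")  # unquoted

    print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
    print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO

    quoted = exp.to_identifier("FoO", quoted=True)
    print(Dialect.get_or_raise("postgres").normalize_identifier(quoted.copy()).name)  # FoO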
872 def case_sensitive(self, text: str) -> bool: 873 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 874 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 875 return False 876 877 unsafe = ( 878 str.islower 879 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 880 else str.isupper 881 ) 882 return any(unsafe(char) for char in text)
Checks if text contains any case-sensitive characters, based on the dialect's rules.
884 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 885 """Checks if text can be identified given an identify option. 886 887 Args: 888 text: The text to check. 889 identify: 890 `"always"` or `True`: Always returns `True`. 891 `"safe"`: Only returns `True` if the identifier is case-insensitive. 892 893 Returns: 894 Whether the given text can be identified. 895 """ 896 if identify is True or identify == "always": 897 return True 898 899 if identify == "safe": 900 return not self.case_sensitive(text) 901 902 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
  `"always"` or `True`: Always returns `True`.
  `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
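For instance, under Postgres' LOWERCASE strategy (a sketch):

    from sqlglot.dialects.dialect import Dialect

    d = Dialect.get_or_raise("postgres")
    print(d.can_identify("foo"))            # True: all-lowercase is case-insensitive here
    print(d.can_identify("Foo"))            # False: contains a case-sensitive character
    print(d.can_identify("Foo", "always"))  # True: identification is unconditional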
904 def quote_identifier(self, expression: E, identify: bool = True) -> E: 905 """ 906 Adds quotes to a given identifier. 907 908 Args: 909 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 910 identify: If set to `False`, the quotes will only be added if the identifier is deemed 911 "unsafe", with respect to its characters and this dialect's normalization strategy. 912 """ 913 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 914 name = expression.this 915 expression.set( 916 "quoted", 917 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 918 ) 919 920 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed
  "unsafe", with respect to its characters and this dialect's normalization strategy.
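A sketch under Postgres' rules:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    d = Dialect.get_or_raise("postgres")

    print(d.quote_identifier(exp.to_identifier("FoO")).sql())                  # "FoO"
    print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql())  # foo (deemed safe)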
922 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 923 if isinstance(path, exp.Literal): 924 path_text = path.name 925 if path.is_number: 926 path_text = f"[{path_text}]" 927 try: 928 return parse_json_path(path_text, self) 929 except ParseError as e: 930 if self.STRICT_JSON_PATH_SYNTAX: 931 logger.warning(f"Invalid JSON path syntax. {str(e)}") 932 933 return path
982def if_sql( 983 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 984) -> t.Callable[[Generator, exp.If], str]: 985 def _if_sql(self: Generator, expression: exp.If) -> str: 986 return self.func( 987 name, 988 expression.this, 989 expression.args.get("true"), 990 expression.args.get("false") or false_value, 991 ) 992 993 return _if_sql
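Dialects that spell the conditional differently reuse this factory, e.g. by registering `if_sql(name="IIF")` for `exp.If`. A sketch of the end-to-end effect; the T-SQL output shown is what such a mapping is expected to produce:

    import sqlglot

    print(sqlglot.transpile("SELECT IF(x > 0, 1, 0)", read="hive", write="tsql")[0])
    # SELECT IIF(x > 0, 1, 0)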
996def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 997 this = expression.this 998 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 999 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 1000 1001 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1071def str_position_sql( 1072 self: Generator, 1073 expression: exp.StrPosition, 1074 generate_instance: bool = False, 1075 str_position_func_name: str = "STRPOS", 1076) -> str: 1077 this = self.sql(expression, "this") 1078 substr = self.sql(expression, "substr") 1079 position = self.sql(expression, "position") 1080 instance = expression.args.get("instance") if generate_instance else None 1081 position_offset = "" 1082 1083 if position: 1084 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1085 this = self.func("SUBSTR", this, position) 1086 position_offset = f" + {position} - 1" 1087 1088 return self.func(str_position_func_name, this, substr, instance) + position_offset
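Dialects whose string-position function is `STRPOS` route `exp.StrPosition` through this helper; a sketch of the expected round trip (DuckDB output shown under that assumption):

    import sqlglot

    print(sqlglot.transpile("SELECT POSITION('b' IN x)", read="postgres", write="duckdb")[0])
    # SELECT STRPOS(x, 'b')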
1097def var_map_sql( 1098 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1099) -> str: 1100 keys = expression.args["keys"] 1101 values = expression.args["values"] 1102 1103 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1104 self.unsupported("Cannot convert array columns into map.") 1105 return self.func(map_func_name, keys, values) 1106 1107 args = [] 1108 for key, value in zip(keys.expressions, values.expressions): 1109 args.append(self.sql(key)) 1110 args.append(self.sql(value)) 1111 1112 return self.func(map_func_name, *args)
1115def build_formatted_time( 1116 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1117) -> t.Callable[[t.List], E]: 1118 """Helper used for time expressions. 1119 1120 Args: 1121 exp_class: the expression class to instantiate. 1122 dialect: target sql dialect. 1123 default: the default format, True being time. 1124 1125 Returns: 1126 A callable that can be used to return the appropriately formatted time expression. 1127 """ 1128 1129 def _builder(args: t.List): 1130 return exp_class( 1131 this=seq_get(args, 0), 1132 format=Dialect[dialect].format_time( 1133 seq_get(args, 1) 1134 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1135 ), 1136 ) 1137 1138 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format; `True` means the dialect's default `TIME_FORMAT` should be used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
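A sketch building the parser directly; the `exp.StrToTime` target and the MySQL source dialect are arbitrary choices for illustration:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    builder = build_formatted_time(exp.StrToTime, "mysql")
    node = builder([exp.column("x"), exp.Literal.string("%Y-%m-%d")])
    print(type(node).__name__, node.args["format"].sql())  # StrToTime '%Y-%m-%d'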
1141def time_format( 1142 dialect: DialectType = None, 1143) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1144 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1145 """ 1146 Returns the time format for a given expression, unless it's equivalent 1147 to the default time format of the dialect of interest. 1148 """ 1149 time_format = self.format_time(expression) 1150 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1151 1152 return _time_format
1155def build_date_delta( 1156 exp_class: t.Type[E], 1157 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1158 default_unit: t.Optional[str] = "DAY", 1159) -> t.Callable[[t.List], E]: 1160 def _builder(args: t.List) -> E: 1161 unit_based = len(args) == 3 1162 this = args[2] if unit_based else seq_get(args, 0) 1163 unit = None 1164 if unit_based or default_unit: 1165 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1166 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1167 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1168 1169 return _builder
1172def build_date_delta_with_interval( 1173 expression_class: t.Type[E], 1174) -> t.Callable[[t.List], t.Optional[E]]: 1175 def _builder(args: t.List) -> t.Optional[E]: 1176 if len(args) < 2: 1177 return None 1178 1179 interval = args[1] 1180 1181 if not isinstance(interval, exp.Interval): 1182 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1183 1184 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1185 1186 return _builder
1189def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1190 unit = seq_get(args, 0) 1191 this = seq_get(args, 1) 1192 1193 if isinstance(this, exp.Cast) and this.is_type("date"): 1194 return exp.DateTrunc(unit=unit, this=this) 1195 return exp.TimestampTrunc(this=this, unit=unit)
1198def date_add_interval_sql( 1199 data_type: str, kind: str 1200) -> t.Callable[[Generator, exp.Expression], str]: 1201 def func(self: Generator, expression: exp.Expression) -> str: 1202 this = self.sql(expression, "this") 1203 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1204 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1205 1206 return func
1209def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1210 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1211 args = [unit_to_str(expression), expression.this] 1212 if zone: 1213 args.append(expression.args.get("zone")) 1214 return self.func("DATE_TRUNC", *args) 1215 1216 return _timestamptrunc_sql
1219def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1220 zone = expression.args.get("zone") 1221 if not zone: 1222 from sqlglot.optimizer.annotate_types import annotate_types 1223 1224 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1225 return self.sql(exp.cast(expression.this, target_type)) 1226 if zone.name.lower() in TIMEZONES: 1227 return self.sql( 1228 exp.AtTimeZone( 1229 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1230 zone=zone, 1231 ) 1232 ) 1233 return self.func("TIMESTAMP", expression.this, zone)
1236def no_time_sql(self: Generator, expression: exp.Time) -> str: 1237 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1238 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1239 expr = exp.cast( 1240 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1241 ) 1242 return self.sql(expr)
1245def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1246 this = expression.this 1247 expr = expression.expression 1248 1249 if expr.name.lower() in TIMEZONES: 1250 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1251 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1252 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1253 return self.sql(this) 1254 1255 this = exp.cast(this, exp.DataType.Type.DATE) 1256 expr = exp.cast(expr, exp.DataType.Type.TIME) 1257 1258 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1290def timestrtotime_sql( 1291 self: Generator, 1292 expression: exp.TimeStrToTime, 1293 include_precision: bool = False, 1294) -> str: 1295 datatype = exp.DataType.build( 1296 exp.DataType.Type.TIMESTAMPTZ 1297 if expression.args.get("zone") 1298 else exp.DataType.Type.TIMESTAMP 1299 ) 1300 1301 if isinstance(expression.this, exp.Literal) and include_precision: 1302 precision = subsecond_precision(expression.this.name) 1303 if precision > 0: 1304 datatype = exp.DataType.build( 1305 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1306 ) 1307 1308 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
1316def encode_decode_sql( 1317 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1318) -> str: 1319 charset = expression.args.get("charset") 1320 if charset and charset.name.lower() != "utf-8": 1321 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1322 1323 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1336def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1337 cond = expression.this 1338 1339 if isinstance(expression.this, exp.Distinct): 1340 cond = expression.this.expressions[0] 1341 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1342 1343 return self.func("sum", exp.func("if", cond, 1, 0))
1346def trim_sql(self: Generator, expression: exp.Trim) -> str: 1347 target = self.sql(expression, "this") 1348 trim_type = self.sql(expression, "position") 1349 remove_chars = self.sql(expression, "expression") 1350 collation = self.sql(expression, "collation") 1351 1352 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1353 if not remove_chars: 1354 return self.trim_sql(expression) 1355 1356 trim_type = f"{trim_type} " if trim_type else "" 1357 remove_chars = f"{remove_chars} " if remove_chars else "" 1358 from_part = "FROM " if trim_type or remove_chars else "" 1359 collation = f" COLLATE {collation}" if collation else "" 1360 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1381@unsupported_args("position", "occurrence", "parameters") 1382def regexp_extract_sql( 1383 self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll 1384) -> str: 1385 group = expression.args.get("group") 1386 1387 # Do not render group if it's the default value for this dialect 1388 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1389 group = None 1390 1391 return self.func(expression.sql_name(), expression.this, expression.expression, group)
1401def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1402 names = [] 1403 for agg in aggregations: 1404 if isinstance(agg, exp.Alias): 1405 names.append(agg.alias) 1406 else: 1407 """ 1408 This case corresponds to aggregations without aliases being used as suffixes 1409 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1410 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1411 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1412 """ 1413 agg_all_unquoted = agg.transform( 1414 lambda node: ( 1415 exp.Identifier(this=node.name, quoted=False) 1416 if isinstance(node, exp.Identifier) 1417 else node 1418 ) 1419 ) 1420 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1421 1422 return names
1462def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1463 @unsupported_args("count") 1464 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1465 return self.func(name, expression.this, expression.expression) 1466 1467 return _arg_max_or_min_sql
1470def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1471 this = expression.this.copy() 1472 1473 return_type = expression.return_type 1474 if return_type.is_type(exp.DataType.Type.DATE): 1475 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1476 # can truncate timestamp strings, because some dialects can't cast them to DATE 1477 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1478 1479 expression.this.replace(exp.cast(this, return_type)) 1480 return expression
1483def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1484 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1485 if cast and isinstance(expression, exp.TsOrDsAdd): 1486 expression = ts_or_ds_add_cast(expression) 1487 1488 return self.func( 1489 name, 1490 unit_to_var(expression), 1491 expression.expression, 1492 expression.this, 1493 ) 1494 1495 return _delta_sql
1498def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1499 unit = expression.args.get("unit") 1500 1501 if isinstance(unit, exp.Placeholder): 1502 return unit 1503 if unit: 1504 return exp.Literal.string(unit.name) 1505 return exp.Literal.string(default) if default else None
1535def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1536 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1537 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1538 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1539 1540 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1543def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1544 """Remove table refs from columns in when statements.""" 1545 alias = expression.this.args.get("alias") 1546 1547 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1548 return self.dialect.normalize_identifier(identifier).name if identifier else None 1549 1550 targets = {normalize(expression.this.this)} 1551 1552 if alias: 1553 targets.add(normalize(alias.this)) 1554 1555 for when in expression.expressions: 1556 # only remove the target names from the THEN clause 1557 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1558 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1559 then = when.args.get("then") 1560 if then: 1561 then.transform( 1562 lambda node: ( 1563 exp.column(node.this) 1564 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1565 else node 1566 ), 1567 copy=False, 1568 ) 1569 1570 return self.merge_sql(expression)
Remove table references from columns in WHEN clauses.
1573def build_json_extract_path( 1574 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1575) -> t.Callable[[t.List], F]: 1576 def _builder(args: t.List) -> F: 1577 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1578 for arg in args[1:]: 1579 if not isinstance(arg, exp.Literal): 1580 # We use the fallback parser because we can't really transpile non-literals safely 1581 return expr_type.from_arg_list(args) 1582 1583 text = arg.name 1584 if is_int(text): 1585 index = int(text) 1586 segments.append( 1587 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1588 ) 1589 else: 1590 segments.append(exp.JSONPathKey(this=text)) 1591 1592 # This is done to avoid failing in the expression validator due to the arg count 1593 del args[2:] 1594 return expr_type( 1595 this=seq_get(args, 0), 1596 expression=exp.JSONPath(expressions=segments), 1597 only_json_types=arrow_req_json_type, 1598 ) 1599 1600 return _builder
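A sketch showing the structured path being built; with `zero_based_indexing=False`, a 1-based index is shifted down:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_json_extract_path

    builder = build_json_extract_path(exp.JSONExtract)
    node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
    print(node.args["expression"].sql())  # '$.a[0]'

    one_based = build_json_extract_path(exp.JSONExtract, zero_based_indexing=False)
    node = one_based([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(1)])
    print(node.args["expression"].sql())  # '$.a[0]' again: the 1-based index was shifted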
1603def json_extract_segments( 1604 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1605) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1606 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1607 path = expression.expression 1608 if not isinstance(path, exp.JSONPath): 1609 return rename_func(name)(self, expression) 1610 1611 escape = path.args.get("escape") 1612 1613 segments = [] 1614 for segment in path.expressions: 1615 path = self.sql(segment) 1616 if path: 1617 if isinstance(segment, exp.JSONPathPart) and ( 1618 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1619 ): 1620 if escape: 1621 path = self.escape_str(path) 1622 1623 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1624 1625 segments.append(path) 1626 1627 if op: 1628 return f" {op} ".join([self.sql(expression.this), *segments]) 1629 return self.func(name, expression.this, *segments) 1630 1631 return _json_extract_segments
1641def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1642 cond = expression.expression 1643 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1644 alias = cond.expressions[0] 1645 cond = cond.this 1646 elif isinstance(cond, exp.Predicate): 1647 alias = "_u" 1648 else: 1649 self.unsupported("Unsupported filter condition") 1650 return "" 1651 1652 unnest = exp.Unnest(expressions=[expression.this]) 1653 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1654 return self.sql(exp.Array(expressions=[filtered]))
1666def build_default_decimal_type( 1667 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1668) -> t.Callable[[exp.DataType], exp.DataType]: 1669 def _builder(dtype: exp.DataType) -> exp.DataType: 1670 if dtype.expressions or precision is None: 1671 return dtype 1672 1673 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1674 return exp.DataType.build(f"DECIMAL({params})") 1675 1676 return _builder
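A sketch; the `38, 9` defaults are arbitrary:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_default_decimal_type

    builder = build_default_decimal_type(precision=38, scale=9)
    print(builder(exp.DataType.build("DECIMAL")).sql())         # DECIMAL(38, 9)
    print(builder(exp.DataType.build("DECIMAL(10, 2)")).sql())  # DECIMAL(10, 2), already parameterized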
1679def build_timestamp_from_parts(args: t.List) -> exp.Func: 1680 if len(args) == 2: 1681 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1682 # so we parse this into Anonymous for now instead of introducing complexity 1683 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1684 1685 return exp.TimestampFromParts.from_arg_list(args)
1692def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1693 start = expression.args.get("start") 1694 end = expression.args.get("end") 1695 step = expression.args.get("step") 1696 1697 if isinstance(start, exp.Cast): 1698 target_type = start.to 1699 elif isinstance(end, exp.Cast): 1700 target_type = end.to 1701 else: 1702 target_type = None 1703 1704 if start and end and target_type and target_type.is_type("date", "timestamp"): 1705 if isinstance(start, exp.Cast) and target_type is start.to: 1706 end = exp.cast(end, target_type) 1707 else: 1708 start = exp.cast(start, target_type) 1709 1710 return self.func("SEQUENCE", start, end, step)
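A sketch of the cast propagation, assuming a dialect like Presto/Trino that maps `exp.GenerateSeries` through this helper; the output shown is roughly what's expected:

    import sqlglot

    sql = "SELECT GENERATE_SERIES(CAST('2020-01-01' AS DATE), end_date)"
    print(sqlglot.transpile(sql, read="postgres", write="presto")[0])
    # SELECT SEQUENCE(CAST('2020-01-01' AS DATE), CAST(end_date AS DATE))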
1713def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: 1714 def _builder(args: t.List, dialect: Dialect) -> E: 1715 return expr_type( 1716 this=seq_get(args, 0), 1717 expression=seq_get(args, 1), 1718 group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1719 parameters=seq_get(args, 3), 1720 ) 1721 1722 return _builder
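A sketch showing the default group injection:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect, build_regexp_extract

    builder = build_regexp_extract(exp.RegexpExtract)
    node = builder([exp.column("x"), exp.Literal.string("a+")], Dialect.get_or_raise("duckdb"))
    print(node.args["group"].sql())  # 0, i.e. the dialect's REGEXP_EXTRACT_DEFAULT_GROUP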
1725def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1726 if isinstance(expression.this, exp.Explode): 1727 return self.sql( 1728 exp.Join( 1729 this=exp.Unnest( 1730 expressions=[expression.this.this], 1731 alias=expression.args.get("alias"), 1732 offset=isinstance(expression.this, exp.Posexplode), 1733 ), 1734 kind="cross", 1735 ) 1736 ) 1737 return self.lateral_sql(expression)
1744def no_make_interval_sql(self: Generator, expression: exp.MakeInterval, sep: str = ", ") -> str: 1745 args = [] 1746 for unit, value in expression.args.items(): 1747 if isinstance(value, exp.Kwarg): 1748 value = value.expression 1749 1750 args.append(f"{value} {unit}") 1751 1752 return f"INTERVAL '{self.format_args(*args, sep=sep)}'"