sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator, unsupported_args
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses, to_bool
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time, subsecond_precision
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[
    exp.DateAdd,
    exp.DateDiff,
    exp.DateSub,
    exp.TsOrDsAdd,
    exp.TsOrDsDiff,
]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    from sqlglot.optimizer.annotate_types import TypeAnnotator

    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DRUID = "druid"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
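
# Usage sketch (editor's illustration, not part of the module): how the
# normalization strategies above play out in practice. Assumes the standard
# dialect registry has been populated, e.g. by `import sqlglot`.
#
#   >>> import sqlglot
#   >>> from sqlglot import exp
#   >>> from sqlglot.dialects.dialect import Dialect
#   >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
#   'FOO'
#   >>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
#   'foo'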

class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        klass.INVERSE_CREATABLE_KIND_MAPPING = {
            v: k for k, v in klass.CREATABLE_KIND_MAPPING.items()
        }

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
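
# Usage sketch (editor's illustration, not part of the module): the metaclass
# above registers every Dialect subclass under a lowercased key and derives
# fresh Tokenizer/Parser/Generator subclasses for it. `MyDialect` is a
# hypothetical name.
#
#   >>> from sqlglot.dialects.dialect import Dialect
#   >>> class MyDialect(Dialect):
#   ...     pass
#   >>> Dialect.get("mydialect") is MyDialect
#   True
#   >>> MyDialect.tokenizer_class is not Dialect.tokenizer_class
#   True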
"""The base index offset for arrays.""" 228 229 WEEK_OFFSET = 0 230 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 231 232 UNNEST_COLUMN_ONLY = False 233 """Whether `UNNEST` table aliases are treated as column aliases.""" 234 235 ALIAS_POST_TABLESAMPLE = False 236 """Whether the table alias comes after tablesample.""" 237 238 TABLESAMPLE_SIZE_IS_PERCENT = False 239 """Whether a size in the table sample clause represents percentage.""" 240 241 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 242 """Specifies the strategy according to which identifiers should be normalized.""" 243 244 IDENTIFIERS_CAN_START_WITH_DIGIT = False 245 """Whether an unquoted identifier can start with a digit.""" 246 247 DPIPE_IS_STRING_CONCAT = True 248 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 249 250 STRICT_STRING_CONCAT = False 251 """Whether `CONCAT`'s arguments must be strings.""" 252 253 SUPPORTS_USER_DEFINED_TYPES = True 254 """Whether user-defined data types are supported.""" 255 256 SUPPORTS_SEMI_ANTI_JOIN = True 257 """Whether `SEMI` or `ANTI` joins are supported.""" 258 259 SUPPORTS_COLUMN_JOIN_MARKS = False 260 """Whether the old-style outer join (+) syntax is supported.""" 261 262 COPY_PARAMS_ARE_CSV = True 263 """Separator of COPY statement parameters.""" 264 265 NORMALIZE_FUNCTIONS: bool | str = "upper" 266 """ 267 Determines how function names are going to be normalized. 268 Possible values: 269 "upper" or True: Convert names to uppercase. 270 "lower": Convert names to lowercase. 271 False: Disables function name normalization. 272 """ 273 274 PRESERVE_ORIGINAL_NAMES: bool = False 275 """ 276 Whether the name of the function should be preserved inside the node's metadata, 277 can be useful for roundtripping deprecated vs new functions that share an AST node 278 e.g JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery 279 """ 280 281 LOG_BASE_FIRST: t.Optional[bool] = True 282 """ 283 Whether the base comes first in the `LOG` function. 284 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 285 """ 286 287 NULL_ORDERING = "nulls_are_small" 288 """ 289 Default `NULL` ordering method to use if not explicitly set. 290 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 291 """ 292 293 TYPED_DIVISION = False 294 """ 295 Whether the behavior of `a / b` depends on the types of `a` and `b`. 296 False means `a / b` is always float division. 297 True means `a / b` is integer division if both `a` and `b` are integers. 
298 """ 299 300 SAFE_DIVISION = False 301 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 302 303 CONCAT_COALESCE = False 304 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 305 306 HEX_LOWERCASE = False 307 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 308 309 DATE_FORMAT = "'%Y-%m-%d'" 310 DATEINT_FORMAT = "'%Y%m%d'" 311 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 312 313 TIME_MAPPING: t.Dict[str, str] = {} 314 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 315 316 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 317 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 318 FORMAT_MAPPING: t.Dict[str, str] = {} 319 """ 320 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 321 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 322 """ 323 324 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 325 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 326 327 PSEUDOCOLUMNS: t.Set[str] = set() 328 """ 329 Columns that are auto-generated by the engine corresponding to this dialect. 330 For example, such columns may be excluded from `SELECT *` queries. 331 """ 332 333 PREFER_CTE_ALIAS_COLUMN = False 334 """ 335 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 336 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 337 any projection aliases in the subquery. 338 339 For example, 340 WITH y(c) AS ( 341 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 342 ) SELECT c FROM y; 343 344 will be rewritten as 345 346 WITH y(c) AS ( 347 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 348 ) SELECT c FROM y; 349 """ 350 351 COPY_PARAMS_ARE_CSV = True 352 """ 353 Whether COPY statement parameters are separated by comma or whitespace 354 """ 355 356 FORCE_EARLY_ALIAS_REF_EXPANSION = False 357 """ 358 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 

    For example:
        WITH data AS (
            SELECT
                1 AS id,
                2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id
        HAVING
            my_id = 1

    In most dialects, "my_id" would refer to "data.my_id" across the query, except:
        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e.
          it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - Clickhouse, which will forward the alias across the query i.e. it resolves
          to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    SUPPORTS_ORDER_BY_ALL = False
    """
    Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
    """

    HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
    """
    Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3)
    as the former is of type INT[] vs the latter which is SUPER
    """

    SUPPORTS_FIXED_SIZE_ARRAYS = False
    """
    Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g.
    in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should
    be interpreted as a subscript/index operator.
    """

    STRICT_JSON_PATH_SYNTAX = True
    """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning."""

    ON_CONDITION_EMPTY_BEFORE_ERROR = True
    """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle)."""

    ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True
    """Whether ArrayAgg needs to filter NULL values."""

    PROMOTE_TO_INFERRED_DATETIME_TYPE = False
    """
    This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted
    to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal
    is cast to x's type to match it instead.
    """

    SUPPORTS_VALUES_DEFAULT = True
    """Whether the DEFAULT keyword is supported in the VALUES clause."""

    NUMBERS_CAN_BE_UNDERSCORE_SEPARATED = False
    """Whether number literals can include underscores for better readability."""

    REGEXP_EXTRACT_DEFAULT_GROUP = 0
    """The default value for the capturing group."""

    SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = {
        exp.Except: True,
        exp.Intersect: True,
        exp.Union: True,
    }
    """
    Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL`
    must be explicitly specified.
    """

    CREATABLE_KIND_MAPPING: dict[str, str] = {}
    """
    Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse
    equivalent of CREATE SCHEMA is CREATE DATABASE.
444 """ 445 446 # --- Autofilled --- 447 448 tokenizer_class = Tokenizer 449 jsonpath_tokenizer_class = JSONPathTokenizer 450 parser_class = Parser 451 generator_class = Generator 452 453 # A trie of the time_mapping keys 454 TIME_TRIE: t.Dict = {} 455 FORMAT_TRIE: t.Dict = {} 456 457 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 458 INVERSE_TIME_TRIE: t.Dict = {} 459 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 460 INVERSE_FORMAT_TRIE: t.Dict = {} 461 462 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 463 464 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 465 466 # Delimiters for string literals and identifiers 467 QUOTE_START = "'" 468 QUOTE_END = "'" 469 IDENTIFIER_START = '"' 470 IDENTIFIER_END = '"' 471 472 # Delimiters for bit, hex, byte and unicode literals 473 BIT_START: t.Optional[str] = None 474 BIT_END: t.Optional[str] = None 475 HEX_START: t.Optional[str] = None 476 HEX_END: t.Optional[str] = None 477 BYTE_START: t.Optional[str] = None 478 BYTE_END: t.Optional[str] = None 479 UNICODE_START: t.Optional[str] = None 480 UNICODE_END: t.Optional[str] = None 481 482 DATE_PART_MAPPING = { 483 "Y": "YEAR", 484 "YY": "YEAR", 485 "YYY": "YEAR", 486 "YYYY": "YEAR", 487 "YR": "YEAR", 488 "YEARS": "YEAR", 489 "YRS": "YEAR", 490 "MM": "MONTH", 491 "MON": "MONTH", 492 "MONS": "MONTH", 493 "MONTHS": "MONTH", 494 "D": "DAY", 495 "DD": "DAY", 496 "DAYS": "DAY", 497 "DAYOFMONTH": "DAY", 498 "DAY OF WEEK": "DAYOFWEEK", 499 "WEEKDAY": "DAYOFWEEK", 500 "DOW": "DAYOFWEEK", 501 "DW": "DAYOFWEEK", 502 "WEEKDAY_ISO": "DAYOFWEEKISO", 503 "DOW_ISO": "DAYOFWEEKISO", 504 "DW_ISO": "DAYOFWEEKISO", 505 "DAY OF YEAR": "DAYOFYEAR", 506 "DOY": "DAYOFYEAR", 507 "DY": "DAYOFYEAR", 508 "W": "WEEK", 509 "WK": "WEEK", 510 "WEEKOFYEAR": "WEEK", 511 "WOY": "WEEK", 512 "WY": "WEEK", 513 "WEEK_ISO": "WEEKISO", 514 "WEEKOFYEARISO": "WEEKISO", 515 "WEEKOFYEAR_ISO": "WEEKISO", 516 "Q": "QUARTER", 517 "QTR": "QUARTER", 518 "QTRS": "QUARTER", 519 "QUARTERS": "QUARTER", 520 "H": "HOUR", 521 "HH": "HOUR", 522 "HR": "HOUR", 523 "HOURS": "HOUR", 524 "HRS": "HOUR", 525 "M": "MINUTE", 526 "MI": "MINUTE", 527 "MIN": "MINUTE", 528 "MINUTES": "MINUTE", 529 "MINS": "MINUTE", 530 "S": "SECOND", 531 "SEC": "SECOND", 532 "SECONDS": "SECOND", 533 "SECS": "SECOND", 534 "MS": "MILLISECOND", 535 "MSEC": "MILLISECOND", 536 "MSECS": "MILLISECOND", 537 "MSECOND": "MILLISECOND", 538 "MSECONDS": "MILLISECOND", 539 "MILLISEC": "MILLISECOND", 540 "MILLISECS": "MILLISECOND", 541 "MILLISECON": "MILLISECOND", 542 "MILLISECONDS": "MILLISECOND", 543 "US": "MICROSECOND", 544 "USEC": "MICROSECOND", 545 "USECS": "MICROSECOND", 546 "MICROSEC": "MICROSECOND", 547 "MICROSECS": "MICROSECOND", 548 "USECOND": "MICROSECOND", 549 "USECONDS": "MICROSECOND", 550 "MICROSECONDS": "MICROSECOND", 551 "NS": "NANOSECOND", 552 "NSEC": "NANOSECOND", 553 "NANOSEC": "NANOSECOND", 554 "NSECOND": "NANOSECOND", 555 "NSECONDS": "NANOSECOND", 556 "NANOSECS": "NANOSECOND", 557 "EPOCH_SECOND": "EPOCH", 558 "EPOCH_SECONDS": "EPOCH", 559 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 560 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 561 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 562 "TZH": "TIMEZONE_HOUR", 563 "TZM": "TIMEZONE_MINUTE", 564 "DEC": "DECADE", 565 "DECS": "DECADE", 566 "DECADES": "DECADE", 567 "MIL": "MILLENIUM", 568 "MILS": "MILLENIUM", 569 "MILLENIA": "MILLENIUM", 570 "C": "CENTURY", 571 "CENT": "CENTURY", 572 "CENTS": "CENTURY", 573 "CENTURIES": "CENTURY", 574 } 575 576 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 577 exp.DataType.Type.BIGINT: 

    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.Datetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.ToDouble,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Levenshtein,
            exp.Sign,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.JSON: {
            exp.ParseJSON,
        },
        exp.DataType.Type.TIME: {
            exp.Time,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
            exp.Quarter,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.String,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    ANNOTATORS: AnnotatorsType = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Count: lambda self, e: self._annotate_with_type(
            e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT
        ),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Extract: lambda self, e: self._annotate_extract(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<TIMESTAMP>")
        ),
        exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_strings = dialect.split(",")
                kv_pairs = (kv.split("=") for kv in kv_strings)
                kwargs = {}
                for pair in kv_pairs:
                    key = pair[0].strip()
                    value: t.Union[bool | str | None] = None

                    if len(pair) == 1:
                        # Default initialize standalone settings to True
                        value = True
                    elif len(pair) == 2:
                        value = pair[1].strip()

                    kwargs[key] = to_bool(value)

            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        self.settings = kwargs

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
857 """ 858 if ( 859 isinstance(expression, exp.Identifier) 860 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 861 and ( 862 not expression.quoted 863 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 864 ) 865 ): 866 expression.set( 867 "this", 868 ( 869 expression.this.upper() 870 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 871 else expression.this.lower() 872 ), 873 ) 874 875 return expression 876 877 def case_sensitive(self, text: str) -> bool: 878 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 879 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 880 return False 881 882 unsafe = ( 883 str.islower 884 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 885 else str.isupper 886 ) 887 return any(unsafe(char) for char in text) 888 889 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 890 """Checks if text can be identified given an identify option. 891 892 Args: 893 text: The text to check. 894 identify: 895 `"always"` or `True`: Always returns `True`. 896 `"safe"`: Only returns `True` if the identifier is case-insensitive. 897 898 Returns: 899 Whether the given text can be identified. 900 """ 901 if identify is True or identify == "always": 902 return True 903 904 if identify == "safe": 905 return not self.case_sensitive(text) 906 907 return False 908 909 def quote_identifier(self, expression: E, identify: bool = True) -> E: 910 """ 911 Adds quotes to a given identifier. 912 913 Args: 914 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 915 identify: If set to `False`, the quotes will only be added if the identifier is deemed 916 "unsafe", with respect to its characters and this dialect's normalization strategy. 917 """ 918 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 919 name = expression.this 920 expression.set( 921 "quoted", 922 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 923 ) 924 925 return expression 926 927 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 928 if isinstance(path, exp.Literal): 929 path_text = path.name 930 if path.is_number: 931 path_text = f"[{path_text}]" 932 try: 933 return parse_json_path(path_text, self) 934 except ParseError as e: 935 if self.STRICT_JSON_PATH_SYNTAX: 936 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 937 938 return path 939 940 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 941 return self.parser(**opts).parse(self.tokenize(sql), sql) 942 943 def parse_into( 944 self, expression_type: exp.IntoType, sql: str, **opts 945 ) -> t.List[t.Optional[exp.Expression]]: 946 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 947 948 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 949 return self.generator(**opts).generate(expression, copy=copy) 950 951 def transpile(self, sql: str, **opts) -> t.List[str]: 952 return [ 953 self.generate(expression, copy=False, **opts) if expression else "" 954 for expression in self.parse(sql) 955 ] 956 957 def tokenize(self, sql: str) -> t.List[Token]: 958 return self.tokenizer.tokenize(sql) 959 960 @property 961 def tokenizer(self) -> Tokenizer: 962 return self.tokenizer_class(dialect=self) 963 964 @property 965 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 966 return self.jsonpath_tokenizer_class(dialect=self) 967 968 def parser(self, **opts) -> Parser: 969 return self.parser_class(dialect=self, **opts) 970 971 def generator(self, **opts) -> Generator: 972 return self.generator_class(dialect=self, **opts) 973 974 975DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 976 977 978def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 979 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 980 981 982@unsupported_args("accuracy") 983def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 984 return self.func("APPROX_COUNT_DISTINCT", expression.this) 985 986 987def if_sql( 988 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 989) -> t.Callable[[Generator, exp.If], str]: 990 def _if_sql(self: Generator, expression: exp.If) -> str: 991 return self.func( 992 name, 993 expression.this, 994 expression.args.get("true"), 995 expression.args.get("false") or false_value, 996 ) 997 998 return _if_sql 999 1000 1001def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1002 this = expression.this 1003 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 1004 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 1005 1006 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 1007 1008 1009def inline_array_sql(self: Generator, expression: exp.Array) -> str: 1010 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 1011 1012 1013def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 1014 elem = seq_get(expression.expressions, 0) 1015 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 1016 return self.func("ARRAY", elem) 1017 return inline_array_sql(self, expression) 1018 1019 1020def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 1021 return self.like_sql( 1022 exp.Like( 1023 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 1024 ) 1025 ) 1026 1027 1028def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 1029 zone = self.sql(expression, "this") 1030 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 1031 1032 1033def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 1034 if expression.args.get("recursive"): 1035 self.unsupported("Recursive CTEs are 
unsupported") 1036 expression.args["recursive"] = False 1037 return self.with_sql(expression) 1038 1039 1040def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 1041 self.unsupported("TABLESAMPLE unsupported") 1042 return self.sql(expression.this) 1043 1044 1045def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 1046 self.unsupported("PIVOT unsupported") 1047 return "" 1048 1049 1050def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 1051 return self.cast_sql(expression) 1052 1053 1054def no_comment_column_constraint_sql( 1055 self: Generator, expression: exp.CommentColumnConstraint 1056) -> str: 1057 self.unsupported("CommentColumnConstraint unsupported") 1058 return "" 1059 1060 1061def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 1062 self.unsupported("MAP_FROM_ENTRIES unsupported") 1063 return "" 1064 1065 1066def property_sql(self: Generator, expression: exp.Property) -> str: 1067 return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" 1068 1069 1070def strposition_sql( 1071 self: Generator, 1072 expression: exp.StrPosition, 1073 func_name: str = "STRPOS", 1074 supports_position: bool = False, 1075 supports_occurrence: bool = False, 1076 use_ansi_position: bool = True, 1077) -> str: 1078 string = expression.this 1079 substr = expression.args.get("substr") 1080 position = expression.args.get("position") 1081 occurrence = expression.args.get("occurrence") 1082 zero = exp.Literal.number(0) 1083 one = exp.Literal.number(1) 1084 1085 if supports_occurrence and occurrence and supports_position and not position: 1086 position = one 1087 1088 transpile_position = position and not supports_position 1089 if transpile_position: 1090 string = exp.Substring(this=string, start=position) 1091 1092 if func_name == "POSITION" and use_ansi_position: 1093 func = exp.Anonymous(this=func_name, expressions=[exp.In(this=substr, field=string)]) 1094 else: 1095 args = [substr, string] if func_name in ("LOCATE", "CHARINDEX") else [string, substr] 1096 if supports_position: 1097 args.append(position) 1098 if occurrence: 1099 if supports_occurrence: 1100 args.append(occurrence) 1101 else: 1102 self.unsupported(f"{func_name} does not support the occurrence parameter.") 1103 func = exp.Anonymous(this=func_name, expressions=args) 1104 1105 if transpile_position: 1106 func_with_offset = exp.Sub(this=func + position, expression=one) 1107 func_wrapped = exp.If(this=func.eq(zero), true=zero, false=func_with_offset) 1108 return self.sql(func_wrapped) 1109 1110 return self.sql(func) 1111 1112 1113def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 1114 return ( 1115 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 1116 ) 1117 1118 1119def var_map_sql( 1120 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1121) -> str: 1122 keys = expression.args["keys"] 1123 values = expression.args["values"] 1124 1125 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1126 self.unsupported("Cannot convert array columns into map.") 1127 return self.func(map_func_name, keys, values) 1128 1129 args = [] 1130 for key, value in zip(keys.expressions, values.expressions): 1131 args.append(self.sql(key)) 1132 args.append(self.sql(value)) 1133 1134 return self.func(map_func_name, *args) 1135 1136 1137def build_formatted_time( 1138 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = 

def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E],
    unit_mapping: t.Optional[t.Dict[str, str]] = None,
    default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = None
        if unit_based or default_unit:
            unit = args[0] if unit_based else exp.Literal.string(default_unit)
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql
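
# Usage sketch (editor's illustration, not part of the module): builders such
# as build_formatted_time return parser callbacks that convert a dialect's
# time format into sqlglot's canonical (Python strftime) format. The expected
# output below assumes the default generator's function fallback naming.
#
#   >>> import sqlglot
#   >>> from sqlglot import exp
#   >>> from sqlglot.dialects.dialect import build_formatted_time
#   >>> builder = build_formatted_time(exp.StrToTime, "mysql")
#   >>> builder([exp.column("x"), exp.Literal.string("%Y-%m-%d")]).sql()
#   "STR_TO_TIME(x, '%Y-%m-%d')"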

def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    zone = expression.args.get("zone")
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if zone.name.lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=zone,
            )
        )
    return self.func("TIMESTAMP", expression.this, zone)


def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)


def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(
    self: Generator,
    expression: exp.TimeStrToTime,
    include_precision: bool = False,
) -> str:
    datatype = exp.DataType.build(
        exp.DataType.Type.TIMESTAMPTZ
        if expression.args.get("zone")
        else exp.DataType.Type.TIMESTAMP
    )

    if isinstance(expression.this, exp.Literal) and include_precision:
        precision = subsecond_precision(expression.this.name)
        if precision > 0:
            datatype = exp.DataType.build(
                datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))]
            )

    return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and DuckDB, whose functions don't support charset and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)

def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


@unsupported_args("position", "occurrence", "parameters")
def regexp_extract_sql(
    self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll
) -> str:
    group = expression.args.get("group")

    # Do not render group if it's the default value for this dialect
    if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP):
        group = None

    return self.func(expression.sql_name(), expression.this, expression.expression, group)


@unsupported_args("position", "occurrence", "modifiers")
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
1422 """ 1423 agg_all_unquoted = agg.transform( 1424 lambda node: ( 1425 exp.Identifier(this=node.name, quoted=False) 1426 if isinstance(node, exp.Identifier) 1427 else node 1428 ) 1429 ) 1430 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1431 1432 return names 1433 1434 1435def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1436 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1437 1438 1439# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1440def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1441 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1442 1443 1444def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1445 return self.func("MAX", expression.this) 1446 1447 1448def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1449 a = self.sql(expression.left) 1450 b = self.sql(expression.right) 1451 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1452 1453 1454def is_parse_json(expression: exp.Expression) -> bool: 1455 return isinstance(expression, exp.ParseJSON) or ( 1456 isinstance(expression, exp.Cast) and expression.is_type("json") 1457 ) 1458 1459 1460def isnull_to_is_null(args: t.List) -> exp.Expression: 1461 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1462 1463 1464def generatedasidentitycolumnconstraint_sql( 1465 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1466) -> str: 1467 start = self.sql(expression, "start") or "1" 1468 increment = self.sql(expression, "increment") or "1" 1469 return f"IDENTITY({start}, {increment})" 1470 1471 1472def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1473 @unsupported_args("count") 1474 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1475 return self.func(name, expression.this, expression.expression) 1476 1477 return _arg_max_or_min_sql 1478 1479 1480def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1481 this = expression.this.copy() 1482 1483 return_type = expression.return_type 1484 if return_type.is_type(exp.DataType.Type.DATE): 1485 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1486 # can truncate timestamp strings, because some dialects can't cast them to DATE 1487 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1488 1489 expression.this.replace(exp.cast(this, return_type)) 1490 return expression 1491 1492 1493def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1494 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1495 if cast and isinstance(expression, exp.TsOrDsAdd): 1496 expression = ts_or_ds_add_cast(expression) 1497 1498 return self.func( 1499 name, 1500 unit_to_var(expression), 1501 expression.expression, 1502 expression.this, 1503 ) 1504 1505 return _delta_sql 1506 1507 1508def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1509 unit = expression.args.get("unit") 1510 1511 if isinstance(unit, exp.Placeholder): 1512 return unit 1513 if unit: 1514 return exp.Literal.string(unit.name) 1515 return exp.Literal.string(default) if default else None 1516 1517 1518def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1519 unit = expression.args.get("unit") 1520 1521 if isinstance(unit, (exp.Var, exp.Placeholder)): 1522 

@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
    pass


@t.overload
def map_date_part(
    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
    pass


def map_date_part(part, dialect: DialectType = Dialect):
    mapped = (
        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
    )
    return exp.var(mapped) if mapped else part


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.args["whens"].expressions:
        # only remove the target table names from certain parts of WHEN MATCHED / WHEN NOT MATCHED
        # they are still valid in the <condition>, the right hand side of each UPDATE and the VALUES part
        # (not the column list) of the INSERT
        then: exp.Insert | exp.Update | None = when.args.get("then")
        if then:
            if isinstance(then, exp.Update):
                for equals in then.find_all(exp.EQ):
                    equal_lhs = equals.this
                    if (
                        isinstance(equal_lhs, exp.Column)
                        and normalize(equal_lhs.args.get("table")) in targets
                    ):
                        equal_lhs.replace(exp.column(equal_lhs.this))
            if isinstance(then, exp.Insert):
                column_list = then.this
                if isinstance(column_list, exp.Tuple):
                    for column in column_list.expressions:
                        if normalize(column.args.get("table")) in targets:
                            column.replace(exp.column(column.this))

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
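
# Usage sketch (editor's illustration, not part of the module):
# build_json_extract_path folds variadic path arguments into a single
# exp.JSONPath expression.
#
#   >>> from sqlglot import exp
#   >>> from sqlglot.dialects.dialect import build_json_extract_path
#   >>> builder = build_json_extract_path(exp.JSONExtract)
#   >>> node = builder([exp.column("col"), exp.Literal.string("a"), exp.Literal.number(0)])
#   >>> [type(seg).__name__ for seg in node.expression.expressions]
#   ['JSONPathRoot', 'JSONPathKey', 'JSONPathSubscript']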

def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        escape = path.args.get("escape")

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    if escape:
                        path = self.escape_str(path)

                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder


def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)


def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)


def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str:
    start = expression.args.get("start")
    end = expression.args.get("end")
    step = expression.args.get("step")

    if isinstance(start, exp.Cast):
        target_type = start.to
    elif isinstance(end, exp.Cast):
        target_type = end.to
    else:
        target_type = None

    if start and end and target_type and target_type.is_type("date", "timestamp"):
        if isinstance(start, exp.Cast) and target_type is start.to:
            end = exp.cast(end, target_type)
else: 1724 start = exp.cast(start, target_type) 1725 1726 return self.func("SEQUENCE", start, end, step) 1727 1728 1729def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: 1730 def _builder(args: t.List, dialect: Dialect) -> E: 1731 return expr_type( 1732 this=seq_get(args, 0), 1733 expression=seq_get(args, 1), 1734 group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1735 parameters=seq_get(args, 3), 1736 ) 1737 1738 return _builder 1739 1740 1741def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1742 if isinstance(expression.this, exp.Explode): 1743 return self.sql( 1744 exp.Join( 1745 this=exp.Unnest( 1746 expressions=[expression.this.this], 1747 alias=expression.args.get("alias"), 1748 offset=isinstance(expression.this, exp.Posexplode), 1749 ), 1750 kind="cross", 1751 ) 1752 ) 1753 return self.lateral_sql(expression) 1754 1755 1756def timestampdiff_sql(self: Generator, expression: exp.DatetimeDiff | exp.TimestampDiff) -> str: 1757 return self.func("TIMESTAMPDIFF", expression.unit, expression.expression, expression.this) 1758 1759 1760def no_make_interval_sql(self: Generator, expression: exp.MakeInterval, sep: str = ", ") -> str: 1761 args = [] 1762 for unit, value in expression.args.items(): 1763 if isinstance(value, exp.Kwarg): 1764 value = value.expression 1765 1766 args.append(f"{value} {unit}") 1767 1768 return f"INTERVAL '{self.format_args(*args, sep=sep)}'" 1769 1770 1771def length_or_char_length_sql(self: Generator, expression: exp.Length) -> str: 1772 length_func = "LENGTH" if expression.args.get("binary") else "CHAR_LENGTH" 1773 return self.func(length_func, expression.this)
55class Dialects(str, Enum): 56 """Dialects supported by SQLGLot.""" 57 58 DIALECT = "" 59 60 ATHENA = "athena" 61 BIGQUERY = "bigquery" 62 CLICKHOUSE = "clickhouse" 63 DATABRICKS = "databricks" 64 DORIS = "doris" 65 DRILL = "drill" 66 DRUID = "druid" 67 DUCKDB = "duckdb" 68 HIVE = "hive" 69 MATERIALIZE = "materialize" 70 MYSQL = "mysql" 71 ORACLE = "oracle" 72 POSTGRES = "postgres" 73 PRESTO = "presto" 74 PRQL = "prql" 75 REDSHIFT = "redshift" 76 RISINGWAVE = "risingwave" 77 SNOWFLAKE = "snowflake" 78 SPARK = "spark" 79 SPARK2 = "spark2" 80 SQLITE = "sqlite" 81 STARROCKS = "starrocks" 82 TABLEAU = "tableau" 83 TERADATA = "teradata" 84 TRINO = "trino" 85 TSQL = "tsql"
Dialects supported by SQLGlot.
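The string values above double as the dialect names accepted by sqlglot's top-level API. A minimal sketch, assuming the sqlglot package is installed:

import sqlglot

# Read MySQL, write DuckDB, by naming both dialects as plain strings.
result = sqlglot.transpile("SELECT IFNULL(a, 0) FROM t", read="mysql", write="duckdb")[0]
print(result)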
88class NormalizationStrategy(str, AutoName): 89 """Specifies the strategy according to which identifiers should be normalized.""" 90 91 LOWERCASE = auto() 92 """Unquoted identifiers are lowercased.""" 93 94 UPPERCASE = auto() 95 """Unquoted identifiers are uppercased.""" 96 97 CASE_SENSITIVE = auto() 98 """Always case-sensitive, regardless of quotes.""" 99 100 CASE_INSENSITIVE = auto() 101 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
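A short sketch of selecting a strategy at construction time, using the settings-string form that get_or_raise (documented below) accepts:

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Override the default strategy for an existing dialect.
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE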
226class Dialect(metaclass=_Dialect): 227 INDEX_OFFSET = 0 228 """The base index offset for arrays.""" 229 230 WEEK_OFFSET = 0 231 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 232 233 UNNEST_COLUMN_ONLY = False 234 """Whether `UNNEST` table aliases are treated as column aliases.""" 235 236 ALIAS_POST_TABLESAMPLE = False 237 """Whether the table alias comes after tablesample.""" 238 239 TABLESAMPLE_SIZE_IS_PERCENT = False 240 """Whether a size in the table sample clause represents percentage.""" 241 242 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 243 """Specifies the strategy according to which identifiers should be normalized.""" 244 245 IDENTIFIERS_CAN_START_WITH_DIGIT = False 246 """Whether an unquoted identifier can start with a digit.""" 247 248 DPIPE_IS_STRING_CONCAT = True 249 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 250 251 STRICT_STRING_CONCAT = False 252 """Whether `CONCAT`'s arguments must be strings.""" 253 254 SUPPORTS_USER_DEFINED_TYPES = True 255 """Whether user-defined data types are supported.""" 256 257 SUPPORTS_SEMI_ANTI_JOIN = True 258 """Whether `SEMI` or `ANTI` joins are supported.""" 259 260 SUPPORTS_COLUMN_JOIN_MARKS = False 261 """Whether the old-style outer join (+) syntax is supported.""" 262 263 COPY_PARAMS_ARE_CSV = True 264 """Separator of COPY statement parameters.""" 265 266 NORMALIZE_FUNCTIONS: bool | str = "upper" 267 """ 268 Determines how function names are going to be normalized. 269 Possible values: 270 "upper" or True: Convert names to uppercase. 271 "lower": Convert names to lowercase. 272 False: Disables function name normalization. 273 """ 274 275 PRESERVE_ORIGINAL_NAMES: bool = False 276 """ 277 Whether the name of the function should be preserved inside the node's metadata, 278 can be useful for roundtripping deprecated vs new functions that share an AST node 279 e.g JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery 280 """ 281 282 LOG_BASE_FIRST: t.Optional[bool] = True 283 """ 284 Whether the base comes first in the `LOG` function. 285 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 286 """ 287 288 NULL_ORDERING = "nulls_are_small" 289 """ 290 Default `NULL` ordering method to use if not explicitly set. 291 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 292 """ 293 294 TYPED_DIVISION = False 295 """ 296 Whether the behavior of `a / b` depends on the types of `a` and `b`. 297 False means `a / b` is always float division. 298 True means `a / b` is integer division if both `a` and `b` are integers. 
299 """ 300 301 SAFE_DIVISION = False 302 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 303 304 CONCAT_COALESCE = False 305 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 306 307 HEX_LOWERCASE = False 308 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 309 310 DATE_FORMAT = "'%Y-%m-%d'" 311 DATEINT_FORMAT = "'%Y%m%d'" 312 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 313 314 TIME_MAPPING: t.Dict[str, str] = {} 315 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 316 317 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 318 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 319 FORMAT_MAPPING: t.Dict[str, str] = {} 320 """ 321 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 322 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 323 """ 324 325 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 326 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 327 328 PSEUDOCOLUMNS: t.Set[str] = set() 329 """ 330 Columns that are auto-generated by the engine corresponding to this dialect. 331 For example, such columns may be excluded from `SELECT *` queries. 332 """ 333 334 PREFER_CTE_ALIAS_COLUMN = False 335 """ 336 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 337 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 338 any projection aliases in the subquery. 339 340 For example, 341 WITH y(c) AS ( 342 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 343 ) SELECT c FROM y; 344 345 will be rewritten as 346 347 WITH y(c) AS ( 348 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 349 ) SELECT c FROM y; 350 """ 351 352 COPY_PARAMS_ARE_CSV = True 353 """ 354 Whether COPY statement parameters are separated by comma or whitespace 355 """ 356 357 FORCE_EARLY_ALIAS_REF_EXPANSION = False 358 """ 359 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 
360 361 For example: 362 WITH data AS ( 363 SELECT 364 1 AS id, 365 2 AS my_id 366 ) 367 SELECT 368 id AS my_id 369 FROM 370 data 371 WHERE 372 my_id = 1 373 GROUP BY 374 my_id, 375 HAVING 376 my_id = 1 377 378 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 379 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 380 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 381 - Clickhouse, which will forward the alias across the query i.e it resolves 382 to "WHERE id = 1 GROUP BY id HAVING id = 1" 383 """ 384 385 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 386 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 387 388 SUPPORTS_ORDER_BY_ALL = False 389 """ 390 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 391 """ 392 393 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 394 """ 395 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 396 as the former is of type INT[] vs the latter which is SUPER 397 """ 398 399 SUPPORTS_FIXED_SIZE_ARRAYS = False 400 """ 401 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 402 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 403 be interpreted as a subscript/index operator. 404 """ 405 406 STRICT_JSON_PATH_SYNTAX = True 407 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 408 409 ON_CONDITION_EMPTY_BEFORE_ERROR = True 410 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 411 412 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 413 """Whether ArrayAgg needs to filter NULL values.""" 414 415 PROMOTE_TO_INFERRED_DATETIME_TYPE = False 416 """ 417 This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted 418 to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal 419 is cast to x's type to match it instead. 420 """ 421 422 SUPPORTS_VALUES_DEFAULT = True 423 """Whether the DEFAULT keyword is supported in the VALUES clause.""" 424 425 NUMBERS_CAN_BE_UNDERSCORE_SEPARATED = False 426 """Whether number literals can include underscores for better readability""" 427 428 REGEXP_EXTRACT_DEFAULT_GROUP = 0 429 """The default value for the capturing group.""" 430 431 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 432 exp.Except: True, 433 exp.Intersect: True, 434 exp.Union: True, 435 } 436 """ 437 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 438 must be explicitly specified. 439 """ 440 441 CREATABLE_KIND_MAPPING: dict[str, str] = {} 442 """ 443 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 444 equivalent of CREATE SCHEMA is CREATE DATABASE. 
445 """ 446 447 # --- Autofilled --- 448 449 tokenizer_class = Tokenizer 450 jsonpath_tokenizer_class = JSONPathTokenizer 451 parser_class = Parser 452 generator_class = Generator 453 454 # A trie of the time_mapping keys 455 TIME_TRIE: t.Dict = {} 456 FORMAT_TRIE: t.Dict = {} 457 458 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 459 INVERSE_TIME_TRIE: t.Dict = {} 460 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 461 INVERSE_FORMAT_TRIE: t.Dict = {} 462 463 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 464 465 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 466 467 # Delimiters for string literals and identifiers 468 QUOTE_START = "'" 469 QUOTE_END = "'" 470 IDENTIFIER_START = '"' 471 IDENTIFIER_END = '"' 472 473 # Delimiters for bit, hex, byte and unicode literals 474 BIT_START: t.Optional[str] = None 475 BIT_END: t.Optional[str] = None 476 HEX_START: t.Optional[str] = None 477 HEX_END: t.Optional[str] = None 478 BYTE_START: t.Optional[str] = None 479 BYTE_END: t.Optional[str] = None 480 UNICODE_START: t.Optional[str] = None 481 UNICODE_END: t.Optional[str] = None 482 483 DATE_PART_MAPPING = { 484 "Y": "YEAR", 485 "YY": "YEAR", 486 "YYY": "YEAR", 487 "YYYY": "YEAR", 488 "YR": "YEAR", 489 "YEARS": "YEAR", 490 "YRS": "YEAR", 491 "MM": "MONTH", 492 "MON": "MONTH", 493 "MONS": "MONTH", 494 "MONTHS": "MONTH", 495 "D": "DAY", 496 "DD": "DAY", 497 "DAYS": "DAY", 498 "DAYOFMONTH": "DAY", 499 "DAY OF WEEK": "DAYOFWEEK", 500 "WEEKDAY": "DAYOFWEEK", 501 "DOW": "DAYOFWEEK", 502 "DW": "DAYOFWEEK", 503 "WEEKDAY_ISO": "DAYOFWEEKISO", 504 "DOW_ISO": "DAYOFWEEKISO", 505 "DW_ISO": "DAYOFWEEKISO", 506 "DAY OF YEAR": "DAYOFYEAR", 507 "DOY": "DAYOFYEAR", 508 "DY": "DAYOFYEAR", 509 "W": "WEEK", 510 "WK": "WEEK", 511 "WEEKOFYEAR": "WEEK", 512 "WOY": "WEEK", 513 "WY": "WEEK", 514 "WEEK_ISO": "WEEKISO", 515 "WEEKOFYEARISO": "WEEKISO", 516 "WEEKOFYEAR_ISO": "WEEKISO", 517 "Q": "QUARTER", 518 "QTR": "QUARTER", 519 "QTRS": "QUARTER", 520 "QUARTERS": "QUARTER", 521 "H": "HOUR", 522 "HH": "HOUR", 523 "HR": "HOUR", 524 "HOURS": "HOUR", 525 "HRS": "HOUR", 526 "M": "MINUTE", 527 "MI": "MINUTE", 528 "MIN": "MINUTE", 529 "MINUTES": "MINUTE", 530 "MINS": "MINUTE", 531 "S": "SECOND", 532 "SEC": "SECOND", 533 "SECONDS": "SECOND", 534 "SECS": "SECOND", 535 "MS": "MILLISECOND", 536 "MSEC": "MILLISECOND", 537 "MSECS": "MILLISECOND", 538 "MSECOND": "MILLISECOND", 539 "MSECONDS": "MILLISECOND", 540 "MILLISEC": "MILLISECOND", 541 "MILLISECS": "MILLISECOND", 542 "MILLISECON": "MILLISECOND", 543 "MILLISECONDS": "MILLISECOND", 544 "US": "MICROSECOND", 545 "USEC": "MICROSECOND", 546 "USECS": "MICROSECOND", 547 "MICROSEC": "MICROSECOND", 548 "MICROSECS": "MICROSECOND", 549 "USECOND": "MICROSECOND", 550 "USECONDS": "MICROSECOND", 551 "MICROSECONDS": "MICROSECOND", 552 "NS": "NANOSECOND", 553 "NSEC": "NANOSECOND", 554 "NANOSEC": "NANOSECOND", 555 "NSECOND": "NANOSECOND", 556 "NSECONDS": "NANOSECOND", 557 "NANOSECS": "NANOSECOND", 558 "EPOCH_SECOND": "EPOCH", 559 "EPOCH_SECONDS": "EPOCH", 560 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 561 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 562 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 563 "TZH": "TIMEZONE_HOUR", 564 "TZM": "TIMEZONE_MINUTE", 565 "DEC": "DECADE", 566 "DECS": "DECADE", 567 "DECADES": "DECADE", 568 "MIL": "MILLENIUM", 569 "MILS": "MILLENIUM", 570 "MILLENIA": "MILLENIUM", 571 "C": "CENTURY", 572 "CENT": "CENTURY", 573 "CENTS": "CENTURY", 574 "CENTURIES": "CENTURY", 575 } 576 577 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 578 exp.DataType.Type.BIGINT: 
{ 579 exp.ApproxDistinct, 580 exp.ArraySize, 581 exp.Length, 582 }, 583 exp.DataType.Type.BOOLEAN: { 584 exp.Between, 585 exp.Boolean, 586 exp.In, 587 exp.RegexpLike, 588 }, 589 exp.DataType.Type.DATE: { 590 exp.CurrentDate, 591 exp.Date, 592 exp.DateFromParts, 593 exp.DateStrToDate, 594 exp.DiToDate, 595 exp.StrToDate, 596 exp.TimeStrToDate, 597 exp.TsOrDsToDate, 598 }, 599 exp.DataType.Type.DATETIME: { 600 exp.CurrentDatetime, 601 exp.Datetime, 602 exp.DatetimeAdd, 603 exp.DatetimeSub, 604 }, 605 exp.DataType.Type.DOUBLE: { 606 exp.ApproxQuantile, 607 exp.Avg, 608 exp.Exp, 609 exp.Ln, 610 exp.Log, 611 exp.Pow, 612 exp.Quantile, 613 exp.Round, 614 exp.SafeDivide, 615 exp.Sqrt, 616 exp.Stddev, 617 exp.StddevPop, 618 exp.StddevSamp, 619 exp.ToDouble, 620 exp.Variance, 621 exp.VariancePop, 622 }, 623 exp.DataType.Type.INT: { 624 exp.Ceil, 625 exp.DatetimeDiff, 626 exp.DateDiff, 627 exp.TimestampDiff, 628 exp.TimeDiff, 629 exp.DateToDi, 630 exp.Levenshtein, 631 exp.Sign, 632 exp.StrPosition, 633 exp.TsOrDiToDi, 634 }, 635 exp.DataType.Type.JSON: { 636 exp.ParseJSON, 637 }, 638 exp.DataType.Type.TIME: { 639 exp.Time, 640 }, 641 exp.DataType.Type.TIMESTAMP: { 642 exp.CurrentTime, 643 exp.CurrentTimestamp, 644 exp.StrToTime, 645 exp.TimeAdd, 646 exp.TimeStrToTime, 647 exp.TimeSub, 648 exp.TimestampAdd, 649 exp.TimestampSub, 650 exp.UnixToTime, 651 }, 652 exp.DataType.Type.TINYINT: { 653 exp.Day, 654 exp.Month, 655 exp.Week, 656 exp.Year, 657 exp.Quarter, 658 }, 659 exp.DataType.Type.VARCHAR: { 660 exp.ArrayConcat, 661 exp.Concat, 662 exp.ConcatWs, 663 exp.DateToDateStr, 664 exp.GroupConcat, 665 exp.Initcap, 666 exp.Lower, 667 exp.Substring, 668 exp.String, 669 exp.TimeToStr, 670 exp.TimeToTimeStr, 671 exp.Trim, 672 exp.TsOrDsToDateStr, 673 exp.UnixToStr, 674 exp.UnixToTimeStr, 675 exp.Upper, 676 }, 677 } 678 679 ANNOTATORS: AnnotatorsType = { 680 **{ 681 expr_type: lambda self, e: self._annotate_unary(e) 682 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 683 }, 684 **{ 685 expr_type: lambda self, e: self._annotate_binary(e) 686 for expr_type in subclasses(exp.__name__, exp.Binary) 687 }, 688 **{ 689 expr_type: _annotate_with_type_lambda(data_type) 690 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 691 for expr_type in expressions 692 }, 693 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 694 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 695 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 696 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 697 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 698 exp.Bracket: lambda self, e: self._annotate_bracket(e), 699 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 700 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 701 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 702 exp.Count: lambda self, e: self._annotate_with_type( 703 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 704 ), 705 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 706 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 707 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 708 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 709 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 710 exp.Div: lambda self, e: self._annotate_div(e), 711 exp.Dot: lambda 
self, e: self._annotate_dot(e), 712 exp.Explode: lambda self, e: self._annotate_explode(e), 713 exp.Extract: lambda self, e: self._annotate_extract(e), 714 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 715 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 716 e, exp.DataType.build("ARRAY<DATE>") 717 ), 718 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 719 e, exp.DataType.build("ARRAY<TIMESTAMP>") 720 ), 721 exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 722 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 723 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 724 exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 725 exp.Literal: lambda self, e: self._annotate_literal(e), 726 exp.Map: lambda self, e: self._annotate_map(e), 727 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 728 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 729 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 730 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 731 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 732 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 733 exp.Struct: lambda self, e: self._annotate_struct(e), 734 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 735 exp.Timestamp: lambda self, e: self._annotate_with_type( 736 e, 737 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 738 ), 739 exp.ToMap: lambda self, e: self._annotate_to_map(e), 740 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 741 exp.Unnest: lambda self, e: self._annotate_unnest(e), 742 exp.VarMap: lambda self, e: self._annotate_map(e), 743 } 744 745 @classmethod 746 def get_or_raise(cls, dialect: DialectType) -> Dialect: 747 """ 748 Look up a dialect in the global dialect registry and return it if it exists. 749 750 Args: 751 dialect: The target dialect. If this is a string, it can be optionally followed by 752 additional key-value pairs that are separated by commas and are used to specify 753 dialect settings, such as whether the dialect's identifiers are case-sensitive. 754 755 Example: 756 >>> dialect = dialect_class = get_or_raise("duckdb") 757 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 758 759 Returns: 760 The corresponding Dialect instance. 761 """ 762 763 if not dialect: 764 return cls() 765 if isinstance(dialect, _Dialect): 766 return dialect() 767 if isinstance(dialect, Dialect): 768 return dialect 769 if isinstance(dialect, str): 770 try: 771 dialect_name, *kv_strings = dialect.split(",") 772 kv_pairs = (kv.split("=") for kv in kv_strings) 773 kwargs = {} 774 for pair in kv_pairs: 775 key = pair[0].strip() 776 value: t.Union[bool | str | None] = None 777 778 if len(pair) == 1: 779 # Default initialize standalone settings to True 780 value = True 781 elif len(pair) == 2: 782 value = pair[1].strip() 783 784 kwargs[key] = to_bool(value) 785 786 except ValueError: 787 raise ValueError( 788 f"Invalid dialect format: '{dialect}'. " 789 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 
790 ) 791 792 result = cls.get(dialect_name.strip()) 793 if not result: 794 from difflib import get_close_matches 795 796 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 797 if similar: 798 similar = f" Did you mean {similar}?" 799 800 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 801 802 return result(**kwargs) 803 804 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 805 806 @classmethod 807 def format_time( 808 cls, expression: t.Optional[str | exp.Expression] 809 ) -> t.Optional[exp.Expression]: 810 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 811 if isinstance(expression, str): 812 return exp.Literal.string( 813 # the time formats are quoted 814 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 815 ) 816 817 if expression and expression.is_string: 818 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 819 820 return expression 821 822 def __init__(self, **kwargs) -> None: 823 normalization_strategy = kwargs.pop("normalization_strategy", None) 824 825 if normalization_strategy is None: 826 self.normalization_strategy = self.NORMALIZATION_STRATEGY 827 else: 828 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 829 830 self.settings = kwargs 831 832 def __eq__(self, other: t.Any) -> bool: 833 # Does not currently take dialect state into account 834 return type(self) == other 835 836 def __hash__(self) -> int: 837 # Does not currently take dialect state into account 838 return hash(type(self)) 839 840 def normalize_identifier(self, expression: E) -> E: 841 """ 842 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 843 844 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 845 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 846 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 847 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 848 849 There are also dialects like Spark, which are case-insensitive even when quotes are 850 present, and dialects like MySQL, whose resolution rules match those employed by the 851 underlying operating system, for example they may always be case-sensitive in Linux. 852 853 Finally, the normalization behavior of some engines can even be controlled through flags, 854 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 855 856 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 857 that it can analyze queries in the optimizer and successfully capture their semantics. 
858 """ 859 if ( 860 isinstance(expression, exp.Identifier) 861 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 862 and ( 863 not expression.quoted 864 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 865 ) 866 ): 867 expression.set( 868 "this", 869 ( 870 expression.this.upper() 871 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 872 else expression.this.lower() 873 ), 874 ) 875 876 return expression 877 878 def case_sensitive(self, text: str) -> bool: 879 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 880 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 881 return False 882 883 unsafe = ( 884 str.islower 885 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 886 else str.isupper 887 ) 888 return any(unsafe(char) for char in text) 889 890 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 891 """Checks if text can be identified given an identify option. 892 893 Args: 894 text: The text to check. 895 identify: 896 `"always"` or `True`: Always returns `True`. 897 `"safe"`: Only returns `True` if the identifier is case-insensitive. 898 899 Returns: 900 Whether the given text can be identified. 901 """ 902 if identify is True or identify == "always": 903 return True 904 905 if identify == "safe": 906 return not self.case_sensitive(text) 907 908 return False 909 910 def quote_identifier(self, expression: E, identify: bool = True) -> E: 911 """ 912 Adds quotes to a given identifier. 913 914 Args: 915 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 916 identify: If set to `False`, the quotes will only be added if the identifier is deemed 917 "unsafe", with respect to its characters and this dialect's normalization strategy. 918 """ 919 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 920 name = expression.this 921 expression.set( 922 "quoted", 923 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 924 ) 925 926 return expression 927 928 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 929 if isinstance(path, exp.Literal): 930 path_text = path.name 931 if path.is_number: 932 path_text = f"[{path_text}]" 933 try: 934 return parse_json_path(path_text, self) 935 except ParseError as e: 936 if self.STRICT_JSON_PATH_SYNTAX: 937 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 938 939 return path 940 941 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 942 return self.parser(**opts).parse(self.tokenize(sql), sql) 943 944 def parse_into( 945 self, expression_type: exp.IntoType, sql: str, **opts 946 ) -> t.List[t.Optional[exp.Expression]]: 947 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 948 949 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 950 return self.generator(**opts).generate(expression, copy=copy) 951 952 def transpile(self, sql: str, **opts) -> t.List[str]: 953 return [ 954 self.generate(expression, copy=False, **opts) if expression else "" 955 for expression in self.parse(sql) 956 ] 957 958 def tokenize(self, sql: str) -> t.List[Token]: 959 return self.tokenizer.tokenize(sql) 960 961 @property 962 def tokenizer(self) -> Tokenizer: 963 return self.tokenizer_class(dialect=self) 964 965 @property 966 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 967 return self.jsonpath_tokenizer_class(dialect=self) 968 969 def parser(self, **opts) -> Parser: 970 return self.parser_class(dialect=self, **opts) 971 972 def generator(self, **opts) -> Generator: 973 return self.generator_class(dialect=self, **opts)
822 def __init__(self, **kwargs) -> None: 823 normalization_strategy = kwargs.pop("normalization_strategy", None) 824 825 if normalization_strategy is None: 826 self.normalization_strategy = self.NORMALIZATION_STRATEGY 827 else: 828 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 829 830 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the name of the function should be preserved inside the node's metadata. This can be useful for roundtripping deprecated vs. new functions that share an AST node, e.g. JSON_VALUE vs. JSON_EXTRACT_SCALAR in BigQuery.
Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG).
Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".
Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.
A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.
Mapping of an escaped sequence (\n) to its unescaped version (a literal newline).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT * queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,

  WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
  ) SELECT c FROM y;

will be rewritten as

  WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
  ) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:

  WITH data AS (
    SELECT
      1 AS id,
      2 AS my_id
  )
  SELECT
    id AS my_id
  FROM
    data
  WHERE
    my_id = 1
  GROUP BY
    my_id
  HAVING
    my_id = 1

In most dialects, "my_id" would refer to "data.my_id" across the query, except:
- BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- Clickhouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns), as in DuckDB and Spark3/Databricks.
Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3), as the former is of type INT[] while the latter is SUPER.
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts, e.g. in DuckDB. In dialects that don't support fixed-size arrays, such as Snowflake, this should be interpreted as a subscript/index operator.
Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.
Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).
This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal is cast to x's type to match it instead.
Whether number literals can include underscores for better readability.
Whether a set operation uses DISTINCT by default. This is None when either DISTINCT or ALL must be explicitly specified.
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
745 @classmethod 746 def get_or_raise(cls, dialect: DialectType) -> Dialect: 747 """ 748 Look up a dialect in the global dialect registry and return it if it exists. 749 750 Args: 751 dialect: The target dialect. If this is a string, it can be optionally followed by 752 additional key-value pairs that are separated by commas and are used to specify 753 dialect settings, such as whether the dialect's identifiers are case-sensitive. 754 755 Example: 756 >>> dialect = dialect_class = get_or_raise("duckdb") 757 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 758 759 Returns: 760 The corresponding Dialect instance. 761 """ 762 763 if not dialect: 764 return cls() 765 if isinstance(dialect, _Dialect): 766 return dialect() 767 if isinstance(dialect, Dialect): 768 return dialect 769 if isinstance(dialect, str): 770 try: 771 dialect_name, *kv_strings = dialect.split(",") 772 kv_pairs = (kv.split("=") for kv in kv_strings) 773 kwargs = {} 774 for pair in kv_pairs: 775 key = pair[0].strip() 776 value: t.Union[bool | str | None] = None 777 778 if len(pair) == 1: 779 # Default initialize standalone settings to True 780 value = True 781 elif len(pair) == 2: 782 value = pair[1].strip() 783 784 kwargs[key] = to_bool(value) 785 786 except ValueError: 787 raise ValueError( 788 f"Invalid dialect format: '{dialect}'. " 789 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 790 ) 791 792 result = cls.get(dialect_name.strip()) 793 if not result: 794 from difflib import get_close_matches 795 796 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 797 if similar: 798 similar = f" Did you mean {similar}?" 799 800 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 801 802 return result(**kwargs) 803 804 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
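A small usage sketch covering the branches above:

from sqlglot.dialects.dialect import Dialect

duckdb = Dialect.get_or_raise("duckdb")   # string -> instance of the registered dialect
same = Dialect.get_or_raise(duckdb)       # Dialect instances pass through unchanged
default = Dialect.get_or_raise(None)      # falsy values yield the base Dialect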
806 @classmethod 807 def format_time( 808 cls, expression: t.Optional[str | exp.Expression] 809 ) -> t.Optional[exp.Expression]: 810 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 811 if isinstance(expression, str): 812 return exp.Literal.string( 813 # the time formats are quoted 814 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 815 ) 816 817 if expression and expression.is_string: 818 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 819 820 return expression
Converts a time format in this dialect to its equivalent Python strftime format.
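A minimal sketch, assuming (as is the case in sqlglot) that the Hive dialect's TIME_MAPPING covers these tokens:

from sqlglot.dialects.hive import Hive

# The input is quoted, exactly as the format appears in SQL text.
fmt = Hive.format_time("'yyyy-MM-dd'")
print(fmt.name)  # %Y-%m-%d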
840 def normalize_identifier(self, expression: E) -> E: 841 """ 842 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 843 844 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 845 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 846 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 847 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 848 849 There are also dialects like Spark, which are case-insensitive even when quotes are 850 present, and dialects like MySQL, whose resolution rules match those employed by the 851 underlying operating system, for example they may always be case-sensitive in Linux. 852 853 Finally, the normalization behavior of some engines can even be controlled through flags, 854 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 855 856 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 857 that it can analyze queries in the optimizer and successfully capture their semantics. 858 """ 859 if ( 860 isinstance(expression, exp.Identifier) 861 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 862 and ( 863 not expression.quoted 864 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 865 ) 866 ): 867 expression.set( 868 "this", 869 ( 870 expression.this.upper() 871 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 872 else expression.this.lower() 873 ), 874 ) 875 876 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
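A minimal sketch of the behavior described above:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted, so normalization applies
print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO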
878 def case_sensitive(self, text: str) -> bool: 879 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 880 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 881 return False 882 883 unsafe = ( 884 str.islower 885 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 886 else str.isupper 887 ) 888 return any(unsafe(char) for char in text)
Checks if text contains any case-sensitive characters, based on the dialect's rules.
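For example, under Snowflake's uppercase strategy only lowercase characters are "unsafe":

from sqlglot.dialects.dialect import Dialect

snowflake = Dialect.get_or_raise("snowflake")
print(snowflake.case_sensitive("FOO"))  # False: normalization would not change it
print(snowflake.case_sensitive("Foo"))  # True: the lowercase characters are significant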
890 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 891 """Checks if text can be identified given an identify option. 892 893 Args: 894 text: The text to check. 895 identify: 896 `"always"` or `True`: Always returns `True`. 897 `"safe"`: Only returns `True` if the identifier is case-insensitive. 898 899 Returns: 900 Whether the given text can be identified. 901 """ 902 if identify is True or identify == "always": 903 return True 904 905 if identify == "safe": 906 return not self.case_sensitive(text) 907 908 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
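A small sketch under Postgres' lowercase strategy:

from sqlglot.dialects.dialect import Dialect

postgres = Dialect.get_or_raise("postgres")
print(postgres.can_identify("foo"))            # True under the default "safe" option
print(postgres.can_identify("Foo"))            # False: mixed case is case-sensitive here
print(postgres.can_identify("Foo", "always"))  # True: "always" skips the check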
910 def quote_identifier(self, expression: E, identify: bool = True) -> E: 911 """ 912 Adds quotes to a given identifier. 913 914 Args: 915 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 916 identify: If set to `False`, the quotes will only be added if the identifier is deemed 917 "unsafe", with respect to its characters and this dialect's normalization strategy. 918 """ 919 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 920 name = expression.this 921 expression.set( 922 "quoted", 923 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 924 ) 925 926 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an Identifier, this method is a no-op.
- identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
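A minimal sketch with identify=False, so quotes are only added when needed:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

postgres = Dialect.get_or_raise("postgres")
print(postgres.quote_identifier(exp.to_identifier("foo"), identify=False).sql())  # foo
print(postgres.quote_identifier(exp.to_identifier("Foo"), identify=False).sql())  # "Foo"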
928 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 929 if isinstance(path, exp.Literal): 930 path_text = path.name 931 if path.is_number: 932 path_text = f"[{path_text}]" 933 try: 934 return parse_json_path(path_text, self) 935 except ParseError as e: 936 if self.STRICT_JSON_PATH_SYNTAX: 937 logger.warning(f"Invalid JSON path syntax. {str(e)}") 938 939 return path
988def if_sql( 989 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 990) -> t.Callable[[Generator, exp.If], str]: 991 def _if_sql(self: Generator, expression: exp.If) -> str: 992 return self.func( 993 name, 994 expression.this, 995 expression.args.get("true"), 996 expression.args.get("false") or false_value, 997 ) 998 999 return _if_sql
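Helpers like this one are meant to be plugged into a dialect Generator's TRANSFORMS mapping; the same wiring pattern applies to the other *_sql helpers below. A hypothetical sketch (MyDialect and the IIF spelling are illustrative, not from the source):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, if_sql
from sqlglot.generator import Generator as BaseGenerator

class MyDialect(Dialect):
    class Generator(BaseGenerator):
        TRANSFORMS = {**BaseGenerator.TRANSFORMS, exp.If: if_sql(name="IIF")}

node = exp.If(this=exp.column("cond"), true=exp.Literal.number(1), false=exp.Literal.number(0))
print(MyDialect().generate(node))  # IIF(cond, 1, 0)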
1002def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1003 this = expression.this 1004 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 1005 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 1006 1007 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1071def strposition_sql( 1072 self: Generator, 1073 expression: exp.StrPosition, 1074 func_name: str = "STRPOS", 1075 supports_position: bool = False, 1076 supports_occurrence: bool = False, 1077 use_ansi_position: bool = True, 1078) -> str: 1079 string = expression.this 1080 substr = expression.args.get("substr") 1081 position = expression.args.get("position") 1082 occurrence = expression.args.get("occurrence") 1083 zero = exp.Literal.number(0) 1084 one = exp.Literal.number(1) 1085 1086 if supports_occurrence and occurrence and supports_position and not position: 1087 position = one 1088 1089 transpile_position = position and not supports_position 1090 if transpile_position: 1091 string = exp.Substring(this=string, start=position) 1092 1093 if func_name == "POSITION" and use_ansi_position: 1094 func = exp.Anonymous(this=func_name, expressions=[exp.In(this=substr, field=string)]) 1095 else: 1096 args = [substr, string] if func_name in ("LOCATE", "CHARINDEX") else [string, substr] 1097 if supports_position: 1098 args.append(position) 1099 if occurrence: 1100 if supports_occurrence: 1101 args.append(occurrence) 1102 else: 1103 self.unsupported(f"{func_name} does not support the occurrence parameter.") 1104 func = exp.Anonymous(this=func_name, expressions=args) 1105 1106 if transpile_position: 1107 func_with_offset = exp.Sub(this=func + position, expression=one) 1108 func_wrapped = exp.If(this=func.eq(zero), true=zero, false=func_with_offset) 1109 return self.sql(func_wrapped) 1110 1111 return self.sql(func)
1120def var_map_sql( 1121 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1122) -> str: 1123 keys = expression.args["keys"] 1124 values = expression.args["values"] 1125 1126 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1127 self.unsupported("Cannot convert array columns into map.") 1128 return self.func(map_func_name, keys, values) 1129 1130 args = [] 1131 for key, value in zip(keys.expressions, values.expressions): 1132 args.append(self.sql(key)) 1133 args.append(self.sql(value)) 1134 1135 return self.func(map_func_name, *args)
1138def build_formatted_time( 1139 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1140) -> t.Callable[[t.List], E]: 1141 """Helper used for time expressions. 1142 1143 Args: 1144 exp_class: the expression class to instantiate. 1145 dialect: target sql dialect. 1146 default: the default format, True being time. 1147 1148 Returns: 1149 A callable that can be used to return the appropriately formatted time expression. 1150 """ 1151 1152 def _builder(args: t.List): 1153 return exp_class( 1154 this=seq_get(args, 0), 1155 format=Dialect[dialect].format_time( 1156 seq_get(args, 1) 1157 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1158 ), 1159 ) 1160 1161 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format; if True, the dialect's TIME_FORMAT is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
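A hedged sketch, assuming (as sqlglot's MySQL dialect does) that TIME_MAPPING translates %i to Python's %M:

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.StrToTime, "mysql")
node = builder([exp.column("dt"), exp.Literal.string("%Y-%m-%d %i")])
print(node.args["format"].name)  # %Y-%m-%d %M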
1164def time_format( 1165 dialect: DialectType = None, 1166) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1167 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1168 """ 1169 Returns the time format for a given expression, unless it's equivalent 1170 to the default time format of the dialect of interest. 1171 """ 1172 time_format = self.format_time(expression) 1173 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1174 1175 return _time_format
1178def build_date_delta( 1179 exp_class: t.Type[E], 1180 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1181 default_unit: t.Optional[str] = "DAY", 1182) -> t.Callable[[t.List], E]: 1183 def _builder(args: t.List) -> E: 1184 unit_based = len(args) == 3 1185 this = args[2] if unit_based else seq_get(args, 0) 1186 unit = None 1187 if unit_based or default_unit: 1188 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1189 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1190 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1191 1192 return _builder
1195def build_date_delta_with_interval( 1196 expression_class: t.Type[E], 1197) -> t.Callable[[t.List], t.Optional[E]]: 1198 def _builder(args: t.List) -> t.Optional[E]: 1199 if len(args) < 2: 1200 return None 1201 1202 interval = args[1] 1203 1204 if not isinstance(interval, exp.Interval): 1205 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1206 1207 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1208 1209 return _builder
1212def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1213 unit = seq_get(args, 0) 1214 this = seq_get(args, 1) 1215 1216 if isinstance(this, exp.Cast) and this.is_type("date"): 1217 return exp.DateTrunc(unit=unit, this=this) 1218 return exp.TimestampTrunc(this=this, unit=unit)
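A small sketch of the dispatch:

from sqlglot import exp
from sqlglot.dialects.dialect import date_trunc_to_time

# A DATE-typed cast routes to exp.DateTrunc; anything else becomes exp.TimestampTrunc.
args = [exp.Literal.string("month"), exp.cast(exp.column("d"), "date")]
print(type(date_trunc_to_time(args)).__name__)  # DateTrunc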
1221def date_add_interval_sql( 1222 data_type: str, kind: str 1223) -> t.Callable[[Generator, exp.Expression], str]: 1224 def func(self: Generator, expression: exp.Expression) -> str: 1225 this = self.sql(expression, "this") 1226 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1227 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1228 1229 return func
1232def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1233 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1234 args = [unit_to_str(expression), expression.this] 1235 if zone: 1236 args.append(expression.args.get("zone")) 1237 return self.func("DATE_TRUNC", *args) 1238 1239 return _timestamptrunc_sql
1242def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1243 zone = expression.args.get("zone") 1244 if not zone: 1245 from sqlglot.optimizer.annotate_types import annotate_types 1246 1247 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1248 return self.sql(exp.cast(expression.this, target_type)) 1249 if zone.name.lower() in TIMEZONES: 1250 return self.sql( 1251 exp.AtTimeZone( 1252 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1253 zone=zone, 1254 ) 1255 ) 1256 return self.func("TIMESTAMP", expression.this, zone)
1259def no_time_sql(self: Generator, expression: exp.Time) -> str: 1260 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1261 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1262 expr = exp.cast( 1263 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1264 ) 1265 return self.sql(expr)
1268def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1269 this = expression.this 1270 expr = expression.expression 1271 1272 if expr.name.lower() in TIMEZONES: 1273 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1274 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1275 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1276 return self.sql(this) 1277 1278 this = exp.cast(this, exp.DataType.Type.DATE) 1279 expr = exp.cast(expr, exp.DataType.Type.TIME) 1280 1281 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1301def timestrtotime_sql( 1302 self: Generator, 1303 expression: exp.TimeStrToTime, 1304 include_precision: bool = False, 1305) -> str: 1306 datatype = exp.DataType.build( 1307 exp.DataType.Type.TIMESTAMPTZ 1308 if expression.args.get("zone") 1309 else exp.DataType.Type.TIMESTAMP 1310 ) 1311 1312 if isinstance(expression.this, exp.Literal) and include_precision: 1313 precision = subsecond_precision(expression.this.name) 1314 if precision > 0: 1315 datatype = exp.DataType.build( 1316 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1317 ) 1318 1319 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
1327def encode_decode_sql( 1328 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1329) -> str: 1330 charset = expression.args.get("charset") 1331 if charset and charset.name.lower() != "utf-8": 1332 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1333 1334 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1347def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1348 cond = expression.this 1349 1350 if isinstance(expression.this, exp.Distinct): 1351 cond = expression.this.expressions[0] 1352 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1353 1354 return self.func("sum", exp.func("if", cond, 1, 0))
1357def trim_sql(self: Generator, expression: exp.Trim) -> str: 1358 target = self.sql(expression, "this") 1359 trim_type = self.sql(expression, "position") 1360 remove_chars = self.sql(expression, "expression") 1361 collation = self.sql(expression, "collation") 1362 1363 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1364 if not remove_chars: 1365 return self.trim_sql(expression) 1366 1367 trim_type = f"{trim_type} " if trim_type else "" 1368 remove_chars = f"{remove_chars} " if remove_chars else "" 1369 from_part = "FROM " if trim_type or remove_chars else "" 1370 collation = f" COLLATE {collation}" if collation else "" 1371 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1392@unsupported_args("position", "occurrence", "parameters") 1393def regexp_extract_sql( 1394 self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll 1395) -> str: 1396 group = expression.args.get("group") 1397 1398 # Do not render group if it's the default value for this dialect 1399 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1400 group = None 1401 1402 return self.func(expression.sql_name(), expression.this, expression.expression, group)
1412def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1413 names = [] 1414 for agg in aggregations: 1415 if isinstance(agg, exp.Alias): 1416 names.append(agg.alias) 1417 else: 1418 """ 1419 This case corresponds to aggregations without aliases being used as suffixes 1420 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1421 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1422 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1423 """ 1424 agg_all_unquoted = agg.transform( 1425 lambda node: ( 1426 exp.Identifier(this=node.name, quoted=False) 1427 if isinstance(node, exp.Identifier) 1428 else node 1429 ) 1430 ) 1431 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1432 1433 return names
1473def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1474 @unsupported_args("count") 1475 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1476 return self.func(name, expression.this, expression.expression) 1477 1478 return _arg_max_or_min_sql
1481def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1482 this = expression.this.copy() 1483 1484 return_type = expression.return_type 1485 if return_type.is_type(exp.DataType.Type.DATE): 1486 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1487 # can truncate timestamp strings, because some dialects can't cast them to DATE 1488 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1489 1490 expression.this.replace(exp.cast(this, return_type)) 1491 return expression
1494def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1495 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1496 if cast and isinstance(expression, exp.TsOrDsAdd): 1497 expression = ts_or_ds_add_cast(expression) 1498 1499 return self.func( 1500 name, 1501 unit_to_var(expression), 1502 expression.expression, 1503 expression.this, 1504 ) 1505 1506 return _delta_sql
1509def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1510 unit = expression.args.get("unit") 1511 1512 if isinstance(unit, exp.Placeholder): 1513 return unit 1514 if unit: 1515 return exp.Literal.string(unit.name) 1516 return exp.Literal.string(default) if default else None
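A small sketch, including the DAY fallback:

from sqlglot import exp
from sqlglot.dialects.dialect import unit_to_str

with_unit = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(1), unit=exp.var("MONTH"))
without_unit = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(1))
print(unit_to_str(with_unit).sql())     # 'MONTH'
print(unit_to_str(without_unit).sql())  # 'DAY' (the default)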
1546def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1547 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1548 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1549 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1550 1551 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1554def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1555 """Remove table refs from columns in when statements.""" 1556 alias = expression.this.args.get("alias") 1557 1558 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1559 return self.dialect.normalize_identifier(identifier).name if identifier else None 1560 1561 targets = {normalize(expression.this.this)} 1562 1563 if alias: 1564 targets.add(normalize(alias.this)) 1565 1566 for when in expression.args["whens"].expressions: 1567 # only remove the target table names from certain parts of WHEN MATCHED / WHEN NOT MATCHED 1568 # they are still valid in the <condition>, the right hand side of each UPDATE and the VALUES part 1569 # (not the column list) of the INSERT 1570 then: exp.Insert | exp.Update | None = when.args.get("then") 1571 if then: 1572 if isinstance(then, exp.Update): 1573 for equals in then.find_all(exp.EQ): 1574 equal_lhs = equals.this 1575 if ( 1576 isinstance(equal_lhs, exp.Column) 1577 and normalize(equal_lhs.args.get("table")) in targets 1578 ): 1579 equal_lhs.replace(exp.column(equal_lhs.this)) 1580 if isinstance(then, exp.Insert): 1581 column_list = then.this 1582 if isinstance(column_list, exp.Tuple): 1583 for column in column_list.expressions: 1584 if normalize(column.args.get("table")) in targets: 1585 column.replace(exp.column(column.this)) 1586 1587 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1590def build_json_extract_path( 1591 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1592) -> t.Callable[[t.List], F]: 1593 def _builder(args: t.List) -> F: 1594 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1595 for arg in args[1:]: 1596 if not isinstance(arg, exp.Literal): 1597 # We use the fallback parser because we can't really transpile non-literals safely 1598 return expr_type.from_arg_list(args) 1599 1600 text = arg.name 1601 if is_int(text): 1602 index = int(text) 1603 segments.append( 1604 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1605 ) 1606 else: 1607 segments.append(exp.JSONPathKey(this=text)) 1608 1609 # This is done to avoid failing in the expression validator due to the arg count 1610 del args[2:] 1611 return expr_type( 1612 this=seq_get(args, 0), 1613 expression=exp.JSONPath(expressions=segments), 1614 only_json_types=arrow_req_json_type, 1615 ) 1616 1617 return _builder
1620def json_extract_segments( 1621 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1622) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1623 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1624 path = expression.expression 1625 if not isinstance(path, exp.JSONPath): 1626 return rename_func(name)(self, expression) 1627 1628 escape = path.args.get("escape") 1629 1630 segments = [] 1631 for segment in path.expressions: 1632 path = self.sql(segment) 1633 if path: 1634 if isinstance(segment, exp.JSONPathPart) and ( 1635 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1636 ): 1637 if escape: 1638 path = self.escape_str(path) 1639 1640 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1641 1642 segments.append(path) 1643 1644 if op: 1645 return f" {op} ".join([self.sql(expression.this), *segments]) 1646 return self.func(name, expression.this, *segments) 1647 1648 return _json_extract_segments
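# Editorial note: with op set, the segments are rendered as an operator chain,
# e.g. op="->" yields roughly json_col -> 'a' -> 'b' for the path $.a.b;
# without op, they become arguments of a NAME(json_col, 'a', 'b') call.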
1658def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1659 cond = expression.expression 1660 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1661 alias = cond.expressions[0] 1662 cond = cond.this 1663 elif isinstance(cond, exp.Predicate): 1664 alias = "_u" 1665 else: 1666 self.unsupported("Unsupported filter condition") 1667 return "" 1668 1669 unnest = exp.Unnest(expressions=[expression.this]) 1670 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1671 return self.sql(exp.Array(expressions=[filtered]))
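# Editorial note: the rewrite above emulates a higher-order FILTER(arr, x -> p)
# for engines that lack one, conceptually producing
# ARRAY(SELECT x FROM UNNEST(arr) ... WHERE p); bare predicates reuse the
# placeholder alias _u.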
1683def build_default_decimal_type( 1684 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1685) -> t.Callable[[exp.DataType], exp.DataType]: 1686 def _builder(dtype: exp.DataType) -> exp.DataType: 1687 if dtype.expressions or precision is None: 1688 return dtype 1689 1690 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1691 return exp.DataType.build(f"DECIMAL({params})") 1692 1693 return _builder
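# Usage sketch (editorial addition): defaults are only filled in when the
# parsed DECIMAL carries no parameters of its own.
_decimal = build_default_decimal_type(precision=38, scale=9)
assert _decimal(exp.DataType.build("DECIMAL")) == exp.DataType.build("DECIMAL(38, 9)")
assert _decimal(exp.DataType.build("DECIMAL(10, 2)")) == exp.DataType.build("DECIMAL(10, 2)")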
1696def build_timestamp_from_parts(args: t.List) -> exp.Func: 1697 if len(args) == 2: 1698 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1699 # so we parse this into Anonymous for now instead of introducing complexity 1700 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1701 1702 return exp.TimestampFromParts.from_arg_list(args)
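# Usage sketch (editorial addition): the two-argument (date, time) form stays
# Anonymous so it round-trips verbatim, while a full argument list builds a
# typed node.
assert isinstance(
    build_timestamp_from_parts([exp.column("d"), exp.column("t")]), exp.Anonymous
)
assert isinstance(
    build_timestamp_from_parts([exp.Literal.number(n) for n in (2024, 1, 2, 3, 4, 5)]),
    exp.TimestampFromParts,
)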
1709def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1710 start = expression.args.get("start") 1711 end = expression.args.get("end") 1712 step = expression.args.get("step") 1713 1714 if isinstance(start, exp.Cast): 1715 target_type = start.to 1716 elif isinstance(end, exp.Cast): 1717 target_type = end.to 1718 else: 1719 target_type = None 1720 1721 if start and end and target_type and target_type.is_type("date", "timestamp"): 1722 if isinstance(start, exp.Cast) and target_type is start.to: 1723 end = exp.cast(end, target_type) 1724 else: 1725 start = exp.cast(start, target_type) 1726 1727 return self.func("SEQUENCE", start, end, step)
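# Editorial note: when exactly one endpoint is cast to DATE/TIMESTAMP, the cast
# is propagated to the other endpoint so SEQUENCE sees consistent types, e.g.
# GENERATE_SERIES(CAST(x AS DATE), y, step) renders roughly as
# SEQUENCE(CAST(x AS DATE), CAST(y AS DATE), step).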
1730def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: 1731 def _builder(args: t.List, dialect: Dialect) -> E: 1732 return expr_type( 1733 this=seq_get(args, 0), 1734 expression=seq_get(args, 1), 1735 group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1736 parameters=seq_get(args, 3), 1737 ) 1738 1739 return _builder
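# Usage sketch (editorial addition): a parser-side builder, typically wired as
# FUNCTIONS = {..., "REGEXP_EXTRACT": build_regexp_extract(exp.RegexpExtract)};
# the dialect's REGEXP_EXTRACT_DEFAULT_GROUP fills the group when the SQL
# omits it.
_regexp_builder = build_regexp_extract(exp.RegexpExtract)
_extract = _regexp_builder([exp.column("s"), exp.Literal.string("a(b)")], Dialect())
assert _extract.args["group"] == exp.Literal.number(Dialect.REGEXP_EXTRACT_DEFAULT_GROUP)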
1742def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1743 if isinstance(expression.this, exp.Explode): 1744 return self.sql( 1745 exp.Join( 1746 this=exp.Unnest( 1747 expressions=[expression.this.this], 1748 alias=expression.args.get("alias"), 1749 offset=isinstance(expression.this, exp.Posexplode), 1750 ), 1751 kind="cross", 1752 ) 1753 ) 1754 return self.lateral_sql(expression)
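# Editorial note: Hive-style LATERAL VIEW EXPLODE(arr) becomes a
# CROSS JOIN UNNEST(arr); since Posexplode subclasses Explode, the positional
# variant sets offset=True on the Unnest to keep the index column.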
1761def no_make_interval_sql(self: Generator, expression: exp.MakeInterval, sep: str = ", ") -> str: 1762 args = [] 1763 for unit, value in expression.args.items(): 1764 if isinstance(value, exp.Kwarg): 1765 value = value.expression 1766 1767 args.append(f"{value} {unit}") 1768 1769 return f"INTERVAL '{self.format_args(*args, sep=sep)}'"
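# Editorial note: for engines without MAKE_INTERVAL, each populated unit is
# flattened into the interval literal, e.g. a MakeInterval with year=1 and
# day=2 renders roughly as INTERVAL '1 year, 2 day'.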