sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator, unsupported_args
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time, subsecond_precision
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    from sqlglot.optimizer.annotate_types import TypeAnnotator

    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)


class Dialects(str, Enum):
    """Dialects supported by SQLGLot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        klass.INVERSE_CREATABLE_KIND_MAPPING = {
            v: k for k, v in klass.CREATABLE_KIND_MAPPING.items()
        }

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Separator of COPY statement parameters."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    COPY_PARAMS_ARE_CSV = True
    """
    Whether COPY statement parameters are separated by comma or whitespace
    """

    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).

    For example:
        WITH data AS (
        SELECT
            1 AS id,
            2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id,
        HAVING
            my_id = 1

    In most dialects, "my_id" would refer to "data.my_id" across the query, except:
        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e
          it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - Clickhouse, which will forward the alias across the query i.e it resolves
          to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    SUPPORTS_ORDER_BY_ALL = False
    """
    Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
    """

    HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
    """
    Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3)
    as the former is of type INT[] vs the latter which is SUPER
    """

    SUPPORTS_FIXED_SIZE_ARRAYS = False
    """
    Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g.
    in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should
    be interpreted as a subscript/index operator.
    """

    STRICT_JSON_PATH_SYNTAX = True
    """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning."""

    ON_CONDITION_EMPTY_BEFORE_ERROR = True
    """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle)."""

    ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True
    """Whether ArrayAgg needs to filter NULL values."""

    REGEXP_EXTRACT_DEFAULT_GROUP = 0
    """The default value for the capturing group."""

    SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = {
        exp.Except: True,
        exp.Intersect: True,
        exp.Union: True,
    }
    """
    Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL`
    must be explicitly specified.
    """

    CREATABLE_KIND_MAPPING: dict[str, str] = {}
    """
    Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse
    equivalent of CREATE SCHEMA is CREATE DATABASE.
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    jsonpath_tokenizer_class = JSONPathTokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}
    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
    INVERSE_FORMAT_TRIE: t.Dict = {}

    INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    DATE_PART_MAPPING = {
        "Y": "YEAR",
        "YY": "YEAR",
        "YYY": "YEAR",
        "YYYY": "YEAR",
        "YR": "YEAR",
        "YEARS": "YEAR",
        "YRS": "YEAR",
        "MM": "MONTH",
        "MON": "MONTH",
        "MONS": "MONTH",
        "MONTHS": "MONTH",
        "D": "DAY",
        "DD": "DAY",
        "DAYS": "DAY",
        "DAYOFMONTH": "DAY",
        "DAY OF WEEK": "DAYOFWEEK",
        "WEEKDAY": "DAYOFWEEK",
        "DOW": "DAYOFWEEK",
        "DW": "DAYOFWEEK",
        "WEEKDAY_ISO": "DAYOFWEEKISO",
        "DOW_ISO": "DAYOFWEEKISO",
        "DW_ISO": "DAYOFWEEKISO",
        "DAY OF YEAR": "DAYOFYEAR",
        "DOY": "DAYOFYEAR",
        "DY": "DAYOFYEAR",
        "W": "WEEK",
        "WK": "WEEK",
        "WEEKOFYEAR": "WEEK",
        "WOY": "WEEK",
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }

    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.Datetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.ToDouble,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Levenshtein,
            exp.Sign,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.JSON: {
            exp.ParseJSON,
        },
        exp.DataType.Type.TIME: {
            exp.Time,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
            exp.Quarter,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.String,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    ANNOTATORS: AnnotatorsType = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Count: lambda self, e: self._annotate_with_type(
            e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT
        ),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Extract: lambda self, e: self._annotate_extract(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<TIMESTAMP>")
        ),
        exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
733 """ 734 735 if not dialect: 736 return cls() 737 if isinstance(dialect, _Dialect): 738 return dialect() 739 if isinstance(dialect, Dialect): 740 return dialect 741 if isinstance(dialect, str): 742 try: 743 dialect_name, *kv_strings = dialect.split(",") 744 kv_pairs = (kv.split("=") for kv in kv_strings) 745 kwargs = {} 746 for pair in kv_pairs: 747 key = pair[0].strip() 748 value: t.Union[bool | str | None] = None 749 750 if len(pair) == 1: 751 # Default initialize standalone settings to True 752 value = True 753 elif len(pair) == 2: 754 value = pair[1].strip() 755 756 # Coerce the value to boolean if it matches to the truthy/falsy values below 757 value_lower = value.lower() 758 if value_lower in ("true", "1"): 759 value = True 760 elif value_lower in ("false", "0"): 761 value = False 762 763 kwargs[key] = value 764 765 except ValueError: 766 raise ValueError( 767 f"Invalid dialect format: '{dialect}'. " 768 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 769 ) 770 771 result = cls.get(dialect_name.strip()) 772 if not result: 773 from difflib import get_close_matches 774 775 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 776 if similar: 777 similar = f" Did you mean {similar}?" 778 779 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 780 781 return result(**kwargs) 782 783 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 784 785 @classmethod 786 def format_time( 787 cls, expression: t.Optional[str | exp.Expression] 788 ) -> t.Optional[exp.Expression]: 789 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 790 if isinstance(expression, str): 791 return exp.Literal.string( 792 # the time formats are quoted 793 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 794 ) 795 796 if expression and expression.is_string: 797 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 798 799 return expression 800 801 def __init__(self, **kwargs) -> None: 802 normalization_strategy = kwargs.pop("normalization_strategy", None) 803 804 if normalization_strategy is None: 805 self.normalization_strategy = self.NORMALIZATION_STRATEGY 806 else: 807 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 808 809 self.settings = kwargs 810 811 def __eq__(self, other: t.Any) -> bool: 812 # Does not currently take dialect state into account 813 return type(self) == other 814 815 def __hash__(self) -> int: 816 # Does not currently take dialect state into account 817 return hash(type(self)) 818 819 def normalize_identifier(self, expression: E) -> E: 820 """ 821 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 822 823 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 824 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 825 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 826 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 827 828 There are also dialects like Spark, which are case-insensitive even when quotes are 829 present, and dialects like MySQL, whose resolution rules match those employed by the 830 underlying operating system, for example they may always be case-sensitive in Linux. 
831 832 Finally, the normalization behavior of some engines can even be controlled through flags, 833 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 834 835 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 836 that it can analyze queries in the optimizer and successfully capture their semantics. 837 """ 838 if ( 839 isinstance(expression, exp.Identifier) 840 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 841 and ( 842 not expression.quoted 843 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 844 ) 845 ): 846 expression.set( 847 "this", 848 ( 849 expression.this.upper() 850 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 851 else expression.this.lower() 852 ), 853 ) 854 855 return expression 856 857 def case_sensitive(self, text: str) -> bool: 858 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 859 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 860 return False 861 862 unsafe = ( 863 str.islower 864 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 865 else str.isupper 866 ) 867 return any(unsafe(char) for char in text) 868 869 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 870 """Checks if text can be identified given an identify option. 871 872 Args: 873 text: The text to check. 874 identify: 875 `"always"` or `True`: Always returns `True`. 876 `"safe"`: Only returns `True` if the identifier is case-insensitive. 877 878 Returns: 879 Whether the given text can be identified. 880 """ 881 if identify is True or identify == "always": 882 return True 883 884 if identify == "safe": 885 return not self.case_sensitive(text) 886 887 return False 888 889 def quote_identifier(self, expression: E, identify: bool = True) -> E: 890 """ 891 Adds quotes to a given identifier. 892 893 Args: 894 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 895 identify: If set to `False`, the quotes will only be added if the identifier is deemed 896 "unsafe", with respect to its characters and this dialect's normalization strategy. 897 """ 898 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 899 name = expression.this 900 expression.set( 901 "quoted", 902 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 903 ) 904 905 return expression 906 907 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 908 if isinstance(path, exp.Literal): 909 path_text = path.name 910 if path.is_number: 911 path_text = f"[{path_text}]" 912 try: 913 return parse_json_path(path_text, self) 914 except ParseError as e: 915 if self.STRICT_JSON_PATH_SYNTAX: 916 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 917 918 return path 919 920 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 921 return self.parser(**opts).parse(self.tokenize(sql), sql) 922 923 def parse_into( 924 self, expression_type: exp.IntoType, sql: str, **opts 925 ) -> t.List[t.Optional[exp.Expression]]: 926 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 927 928 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 929 return self.generator(**opts).generate(expression, copy=copy) 930 931 def transpile(self, sql: str, **opts) -> t.List[str]: 932 return [ 933 self.generate(expression, copy=False, **opts) if expression else "" 934 for expression in self.parse(sql) 935 ] 936 937 def tokenize(self, sql: str) -> t.List[Token]: 938 return self.tokenizer.tokenize(sql) 939 940 @property 941 def tokenizer(self) -> Tokenizer: 942 return self.tokenizer_class(dialect=self) 943 944 @property 945 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 946 return self.jsonpath_tokenizer_class(dialect=self) 947 948 def parser(self, **opts) -> Parser: 949 return self.parser_class(dialect=self, **opts) 950 951 def generator(self, **opts) -> Generator: 952 return self.generator_class(dialect=self, **opts) 953 954 955DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 956 957 958def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 959 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 960 961 962@unsupported_args("accuracy") 963def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 964 return self.func("APPROX_COUNT_DISTINCT", expression.this) 965 966 967def if_sql( 968 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 969) -> t.Callable[[Generator, exp.If], str]: 970 def _if_sql(self: Generator, expression: exp.If) -> str: 971 return self.func( 972 name, 973 expression.this, 974 expression.args.get("true"), 975 expression.args.get("false") or false_value, 976 ) 977 978 return _if_sql 979 980 981def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 982 this = expression.this 983 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 984 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 985 986 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 987 988 989def inline_array_sql(self: Generator, expression: exp.Array) -> str: 990 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 991 992 993def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 994 elem = seq_get(expression.expressions, 0) 995 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 996 return self.func("ARRAY", elem) 997 return inline_array_sql(self, expression) 998 999 1000def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 1001 return self.like_sql( 1002 exp.Like( 1003 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 1004 ) 1005 ) 1006 1007 1008def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 1009 zone = self.sql(expression, "this") 1010 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 1011 1012 1013def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 1014 if expression.args.get("recursive"): 1015 self.unsupported("Recursive CTEs are unsupported") 1016 
expression.args["recursive"] = False 1017 return self.with_sql(expression) 1018 1019 1020def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide, if_sql: str = "IF") -> str: 1021 n = self.sql(expression, "this") 1022 d = self.sql(expression, "expression") 1023 return f"{if_sql}(({d}) <> 0, ({n}) / ({d}), NULL)" 1024 1025 1026def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 1027 self.unsupported("TABLESAMPLE unsupported") 1028 return self.sql(expression.this) 1029 1030 1031def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 1032 self.unsupported("PIVOT unsupported") 1033 return "" 1034 1035 1036def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 1037 return self.cast_sql(expression) 1038 1039 1040def no_comment_column_constraint_sql( 1041 self: Generator, expression: exp.CommentColumnConstraint 1042) -> str: 1043 self.unsupported("CommentColumnConstraint unsupported") 1044 return "" 1045 1046 1047def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 1048 self.unsupported("MAP_FROM_ENTRIES unsupported") 1049 return "" 1050 1051 1052def property_sql(self: Generator, expression: exp.Property) -> str: 1053 return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" 1054 1055 1056def str_position_sql( 1057 self: Generator, 1058 expression: exp.StrPosition, 1059 generate_instance: bool = False, 1060 str_position_func_name: str = "STRPOS", 1061) -> str: 1062 this = self.sql(expression, "this") 1063 substr = self.sql(expression, "substr") 1064 position = self.sql(expression, "position") 1065 instance = expression.args.get("instance") if generate_instance else None 1066 position_offset = "" 1067 1068 if position: 1069 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1070 this = self.func("SUBSTR", this, position) 1071 position_offset = f" + {position} - 1" 1072 1073 return self.func(str_position_func_name, this, substr, instance) + position_offset 1074 1075 1076def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 1077 return ( 1078 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 1079 ) 1080 1081 1082def var_map_sql( 1083 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1084) -> str: 1085 keys = expression.args["keys"] 1086 values = expression.args["values"] 1087 1088 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1089 self.unsupported("Cannot convert array columns into map.") 1090 return self.func(map_func_name, keys, values) 1091 1092 args = [] 1093 for key, value in zip(keys.expressions, values.expressions): 1094 args.append(self.sql(key)) 1095 args.append(self.sql(value)) 1096 1097 return self.func(map_func_name, *args) 1098 1099 1100def build_formatted_time( 1101 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1102) -> t.Callable[[t.List], E]: 1103 """Helper used for time expressions. 1104 1105 Args: 1106 exp_class: the expression class to instantiate. 1107 dialect: target sql dialect. 1108 default: the default format, True being time. 1109 1110 Returns: 1111 A callable that can be used to return the appropriately formatted time expression. 
1112 """ 1113 1114 def _builder(args: t.List): 1115 return exp_class( 1116 this=seq_get(args, 0), 1117 format=Dialect[dialect].format_time( 1118 seq_get(args, 1) 1119 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1120 ), 1121 ) 1122 1123 return _builder 1124 1125 1126def time_format( 1127 dialect: DialectType = None, 1128) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1129 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1130 """ 1131 Returns the time format for a given expression, unless it's equivalent 1132 to the default time format of the dialect of interest. 1133 """ 1134 time_format = self.format_time(expression) 1135 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1136 1137 return _time_format 1138 1139 1140def build_date_delta( 1141 exp_class: t.Type[E], 1142 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1143 default_unit: t.Optional[str] = "DAY", 1144) -> t.Callable[[t.List], E]: 1145 def _builder(args: t.List) -> E: 1146 unit_based = len(args) == 3 1147 this = args[2] if unit_based else seq_get(args, 0) 1148 unit = None 1149 if unit_based or default_unit: 1150 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1151 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1152 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1153 1154 return _builder 1155 1156 1157def build_date_delta_with_interval( 1158 expression_class: t.Type[E], 1159) -> t.Callable[[t.List], t.Optional[E]]: 1160 def _builder(args: t.List) -> t.Optional[E]: 1161 if len(args) < 2: 1162 return None 1163 1164 interval = args[1] 1165 1166 if not isinstance(interval, exp.Interval): 1167 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1168 1169 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1170 1171 return _builder 1172 1173 1174def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1175 unit = seq_get(args, 0) 1176 this = seq_get(args, 1) 1177 1178 if isinstance(this, exp.Cast) and this.is_type("date"): 1179 return exp.DateTrunc(unit=unit, this=this) 1180 return exp.TimestampTrunc(this=this, unit=unit) 1181 1182 1183def date_add_interval_sql( 1184 data_type: str, kind: str 1185) -> t.Callable[[Generator, exp.Expression], str]: 1186 def func(self: Generator, expression: exp.Expression) -> str: 1187 this = self.sql(expression, "this") 1188 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1189 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1190 1191 return func 1192 1193 1194def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1195 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1196 args = [unit_to_str(expression), expression.this] 1197 if zone: 1198 args.append(expression.args.get("zone")) 1199 return self.func("DATE_TRUNC", *args) 1200 1201 return _timestamptrunc_sql 1202 1203 1204def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1205 zone = expression.args.get("zone") 1206 if not zone: 1207 from sqlglot.optimizer.annotate_types import annotate_types 1208 1209 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1210 return self.sql(exp.cast(expression.this, target_type)) 1211 if zone.name.lower() in TIMEZONES: 1212 return self.sql( 1213 
exp.AtTimeZone( 1214 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1215 zone=zone, 1216 ) 1217 ) 1218 return self.func("TIMESTAMP", expression.this, zone) 1219 1220 1221def no_time_sql(self: Generator, expression: exp.Time) -> str: 1222 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1223 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1224 expr = exp.cast( 1225 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1226 ) 1227 return self.sql(expr) 1228 1229 1230def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1231 this = expression.this 1232 expr = expression.expression 1233 1234 if expr.name.lower() in TIMEZONES: 1235 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1236 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1237 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1238 return self.sql(this) 1239 1240 this = exp.cast(this, exp.DataType.Type.DATE) 1241 expr = exp.cast(expr, exp.DataType.Type.TIME) 1242 1243 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP)) 1244 1245 1246def locate_to_strposition(args: t.List) -> exp.Expression: 1247 return exp.StrPosition( 1248 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 1249 ) 1250 1251 1252def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 1253 return self.func( 1254 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 1255 ) 1256 1257 1258def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1259 return self.sql( 1260 exp.Substring( 1261 this=expression.this, start=exp.Literal.number(1), length=expression.expression 1262 ) 1263 ) 1264 1265 1266def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1267 return self.sql( 1268 exp.Substring( 1269 this=expression.this, 1270 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 1271 ) 1272 ) 1273 1274 1275def timestrtotime_sql( 1276 self: Generator, 1277 expression: exp.TimeStrToTime, 1278 include_precision: bool = False, 1279) -> str: 1280 datatype = exp.DataType.build( 1281 exp.DataType.Type.TIMESTAMPTZ 1282 if expression.args.get("zone") 1283 else exp.DataType.Type.TIMESTAMP 1284 ) 1285 1286 if isinstance(expression.this, exp.Literal) and include_precision: 1287 precision = subsecond_precision(expression.this.name) 1288 if precision > 0: 1289 datatype = exp.DataType.build( 1290 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1291 ) 1292 1293 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect)) 1294 1295 1296def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 1297 return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) 1298 1299 1300# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 1301def encode_decode_sql( 1302 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1303) -> str: 1304 charset = expression.args.get("charset") 1305 if charset and charset.name.lower() != "utf-8": 1306 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1307 1308 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 1309 1310 1311def min_or_least(self: Generator, expression: 
exp.Min) -> str: 1312 name = "LEAST" if expression.expressions else "MIN" 1313 return rename_func(name)(self, expression) 1314 1315 1316def max_or_greatest(self: Generator, expression: exp.Max) -> str: 1317 name = "GREATEST" if expression.expressions else "MAX" 1318 return rename_func(name)(self, expression) 1319 1320 1321def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1322 cond = expression.this 1323 1324 if isinstance(expression.this, exp.Distinct): 1325 cond = expression.this.expressions[0] 1326 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1327 1328 return self.func("sum", exp.func("if", cond, 1, 0)) 1329 1330 1331def trim_sql(self: Generator, expression: exp.Trim) -> str: 1332 target = self.sql(expression, "this") 1333 trim_type = self.sql(expression, "position") 1334 remove_chars = self.sql(expression, "expression") 1335 collation = self.sql(expression, "collation") 1336 1337 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1338 if not remove_chars: 1339 return self.trim_sql(expression) 1340 1341 trim_type = f"{trim_type} " if trim_type else "" 1342 remove_chars = f"{remove_chars} " if remove_chars else "" 1343 from_part = "FROM " if trim_type or remove_chars else "" 1344 collation = f" COLLATE {collation}" if collation else "" 1345 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 1346 1347 1348def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 1349 return self.func("STRPTIME", expression.this, self.format_time(expression)) 1350 1351 1352def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 1353 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 1354 1355 1356def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 1357 delim, *rest_args = expression.expressions 1358 return self.sql( 1359 reduce( 1360 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 1361 rest_args, 1362 ) 1363 ) 1364 1365 1366@unsupported_args("position", "occurrence", "parameters") 1367def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1368 group = expression.args.get("group") 1369 1370 # Do not render group if it's the default value for this dialect 1371 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1372 group = None 1373 1374 return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group) 1375 1376 1377@unsupported_args("position", "occurrence", "modifiers") 1378def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1379 return self.func( 1380 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1381 ) 1382 1383 1384def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1385 names = [] 1386 for agg in aggregations: 1387 if isinstance(agg, exp.Alias): 1388 names.append(agg.alias) 1389 else: 1390 """ 1391 This case corresponds to aggregations without aliases being used as suffixes 1392 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1393 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1394 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 
1395 """ 1396 agg_all_unquoted = agg.transform( 1397 lambda node: ( 1398 exp.Identifier(this=node.name, quoted=False) 1399 if isinstance(node, exp.Identifier) 1400 else node 1401 ) 1402 ) 1403 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1404 1405 return names 1406 1407 1408def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1409 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1410 1411 1412# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1413def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1414 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1415 1416 1417def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1418 return self.func("MAX", expression.this) 1419 1420 1421def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1422 a = self.sql(expression.left) 1423 b = self.sql(expression.right) 1424 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1425 1426 1427def is_parse_json(expression: exp.Expression) -> bool: 1428 return isinstance(expression, exp.ParseJSON) or ( 1429 isinstance(expression, exp.Cast) and expression.is_type("json") 1430 ) 1431 1432 1433def isnull_to_is_null(args: t.List) -> exp.Expression: 1434 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1435 1436 1437def generatedasidentitycolumnconstraint_sql( 1438 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1439) -> str: 1440 start = self.sql(expression, "start") or "1" 1441 increment = self.sql(expression, "increment") or "1" 1442 return f"IDENTITY({start}, {increment})" 1443 1444 1445def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1446 @unsupported_args("count") 1447 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1448 return self.func(name, expression.this, expression.expression) 1449 1450 return _arg_max_or_min_sql 1451 1452 1453def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1454 this = expression.this.copy() 1455 1456 return_type = expression.return_type 1457 if return_type.is_type(exp.DataType.Type.DATE): 1458 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1459 # can truncate timestamp strings, because some dialects can't cast them to DATE 1460 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1461 1462 expression.this.replace(exp.cast(this, return_type)) 1463 return expression 1464 1465 1466def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1467 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1468 if cast and isinstance(expression, exp.TsOrDsAdd): 1469 expression = ts_or_ds_add_cast(expression) 1470 1471 return self.func( 1472 name, 1473 unit_to_var(expression), 1474 expression.expression, 1475 expression.this, 1476 ) 1477 1478 return _delta_sql 1479 1480 1481def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1482 unit = expression.args.get("unit") 1483 1484 if isinstance(unit, exp.Placeholder): 1485 return unit 1486 if unit: 1487 return exp.Literal.string(unit.name) 1488 return exp.Literal.string(default) if default else None 1489 1490 1491def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1492 unit = expression.args.get("unit") 1493 1494 if isinstance(unit, (exp.Var, exp.Placeholder)): 1495 
return unit 1496 return exp.Var(this=default) if default else None 1497 1498 1499@t.overload 1500def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: 1501 pass 1502 1503 1504@t.overload 1505def map_date_part( 1506 part: t.Optional[exp.Expression], dialect: DialectType = Dialect 1507) -> t.Optional[exp.Expression]: 1508 pass 1509 1510 1511def map_date_part(part, dialect: DialectType = Dialect): 1512 mapped = ( 1513 Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None 1514 ) 1515 return exp.var(mapped) if mapped else part 1516 1517 1518def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1519 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1520 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1521 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1522 1523 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1524 1525 1526def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1527 """Remove table refs from columns in when statements.""" 1528 alias = expression.this.args.get("alias") 1529 1530 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1531 return self.dialect.normalize_identifier(identifier).name if identifier else None 1532 1533 targets = {normalize(expression.this.this)} 1534 1535 if alias: 1536 targets.add(normalize(alias.this)) 1537 1538 for when in expression.expressions: 1539 # only remove the target names from the THEN clause 1540 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1541 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1542 then = when.args.get("then") 1543 if then: 1544 then.transform( 1545 lambda node: ( 1546 exp.column(node.this) 1547 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1548 else node 1549 ), 1550 copy=False, 1551 ) 1552 1553 return self.merge_sql(expression) 1554 1555 1556def build_json_extract_path( 1557 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1558) -> t.Callable[[t.List], F]: 1559 def _builder(args: t.List) -> F: 1560 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1561 for arg in args[1:]: 1562 if not isinstance(arg, exp.Literal): 1563 # We use the fallback parser because we can't really transpile non-literals safely 1564 return expr_type.from_arg_list(args) 1565 1566 text = arg.name 1567 if is_int(text): 1568 index = int(text) 1569 segments.append( 1570 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1571 ) 1572 else: 1573 segments.append(exp.JSONPathKey(this=text)) 1574 1575 # This is done to avoid failing in the expression validator due to the arg count 1576 del args[2:] 1577 return expr_type( 1578 this=seq_get(args, 0), 1579 expression=exp.JSONPath(expressions=segments), 1580 only_json_types=arrow_req_json_type, 1581 ) 1582 1583 return _builder 1584 1585 1586def json_extract_segments( 1587 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1588) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1589 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1590 path = expression.expression 1591 if not isinstance(path, exp.JSONPath): 1592 return rename_func(name)(self, expression) 1593 1594 escape = path.args.get("escape") 1595 1596 segments = [] 1597 for segment in path.expressions: 1598 path = self.sql(segment) 1599 if path: 1600 if isinstance(segment, 
exp.JSONPathPart) and ( 1601 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1602 ): 1603 if escape: 1604 path = self.escape_str(path) 1605 1606 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1607 1608 segments.append(path) 1609 1610 if op: 1611 return f" {op} ".join([self.sql(expression.this), *segments]) 1612 return self.func(name, expression.this, *segments) 1613 1614 return _json_extract_segments 1615 1616 1617def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1618 if isinstance(expression.this, exp.JSONPathWildcard): 1619 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1620 1621 return expression.name 1622 1623 1624def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1625 cond = expression.expression 1626 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1627 alias = cond.expressions[0] 1628 cond = cond.this 1629 elif isinstance(cond, exp.Predicate): 1630 alias = "_u" 1631 else: 1632 self.unsupported("Unsupported filter condition") 1633 return "" 1634 1635 unnest = exp.Unnest(expressions=[expression.this]) 1636 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1637 return self.sql(exp.Array(expressions=[filtered])) 1638 1639 1640def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1641 return self.func( 1642 "TO_NUMBER", 1643 expression.this, 1644 expression.args.get("format"), 1645 expression.args.get("nlsparam"), 1646 ) 1647 1648 1649def build_default_decimal_type( 1650 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1651) -> t.Callable[[exp.DataType], exp.DataType]: 1652 def _builder(dtype: exp.DataType) -> exp.DataType: 1653 if dtype.expressions or precision is None: 1654 return dtype 1655 1656 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1657 return exp.DataType.build(f"DECIMAL({params})") 1658 1659 return _builder 1660 1661 1662def build_timestamp_from_parts(args: t.List) -> exp.Func: 1663 if len(args) == 2: 1664 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1665 # so we parse this into Anonymous for now instead of introducing complexity 1666 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1667 1668 return exp.TimestampFromParts.from_arg_list(args) 1669 1670 1671def sha256_sql(self: Generator, expression: exp.SHA2) -> str: 1672 return self.func(f"SHA{expression.text('length') or '256'}", expression.this) 1673 1674 1675def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1676 start = expression.args.get("start") 1677 end = expression.args.get("end") 1678 step = expression.args.get("step") 1679 1680 if isinstance(start, exp.Cast): 1681 target_type = start.to 1682 elif isinstance(end, exp.Cast): 1683 target_type = end.to 1684 else: 1685 target_type = None 1686 1687 if start and end and target_type and target_type.is_type("date", "timestamp"): 1688 if isinstance(start, exp.Cast) and target_type is start.to: 1689 end = exp.cast(end, target_type) 1690 else: 1691 start = exp.cast(start, target_type) 1692 1693 return self.func("SEQUENCE", start, end, step) 1694 1695 1696def build_regexp_extract(args: t.List, dialect: Dialect) -> exp.RegexpExtract: 1697 return exp.RegexpExtract( 1698 this=seq_get(args, 0), 1699 expression=seq_get(args, 1), 1700 group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1701 ) 1702 1703 1704def 
explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1705 if isinstance(expression.this, exp.Explode): 1706 return self.sql( 1707 exp.Join( 1708 this=exp.Unnest( 1709 expressions=[expression.this.this], 1710 alias=expression.args.get("alias"), 1711 offset=isinstance(expression.this, exp.Posexplode), 1712 ), 1713 kind="cross", 1714 ) 1715 ) 1716 return self.lateral_sql(expression)
49class Dialects(str, Enum): 50 """Dialects supported by SQLGLot.""" 51 52 DIALECT = "" 53 54 ATHENA = "athena" 55 BIGQUERY = "bigquery" 56 CLICKHOUSE = "clickhouse" 57 DATABRICKS = "databricks" 58 DORIS = "doris" 59 DRILL = "drill" 60 DUCKDB = "duckdb" 61 HIVE = "hive" 62 MATERIALIZE = "materialize" 63 MYSQL = "mysql" 64 ORACLE = "oracle" 65 POSTGRES = "postgres" 66 PRESTO = "presto" 67 PRQL = "prql" 68 REDSHIFT = "redshift" 69 RISINGWAVE = "risingwave" 70 SNOWFLAKE = "snowflake" 71 SPARK = "spark" 72 SPARK2 = "spark2" 73 SQLITE = "sqlite" 74 STARROCKS = "starrocks" 75 TABLEAU = "tableau" 76 TERADATA = "teradata" 77 TRINO = "trino" 78 TSQL = "tsql"
Dialects supported by SQLGlot.
81class NormalizationStrategy(str, AutoName): 82 """Specifies the strategy according to which identifiers should be normalized.""" 83 84 LOWERCASE = auto() 85 """Unquoted identifiers are lowercased.""" 86 87 UPPERCASE = auto() 88 """Unquoted identifiers are uppercased.""" 89 90 CASE_SENSITIVE = auto() 91 """Always case-sensitive, regardless of quotes.""" 92 93 CASE_INSENSITIVE = auto() 94 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
219class Dialect(metaclass=_Dialect): 220 INDEX_OFFSET = 0 221 """The base index offset for arrays.""" 222 223 WEEK_OFFSET = 0 224 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 225 226 UNNEST_COLUMN_ONLY = False 227 """Whether `UNNEST` table aliases are treated as column aliases.""" 228 229 ALIAS_POST_TABLESAMPLE = False 230 """Whether the table alias comes after tablesample.""" 231 232 TABLESAMPLE_SIZE_IS_PERCENT = False 233 """Whether a size in the table sample clause represents percentage.""" 234 235 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 236 """Specifies the strategy according to which identifiers should be normalized.""" 237 238 IDENTIFIERS_CAN_START_WITH_DIGIT = False 239 """Whether an unquoted identifier can start with a digit.""" 240 241 DPIPE_IS_STRING_CONCAT = True 242 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 243 244 STRICT_STRING_CONCAT = False 245 """Whether `CONCAT`'s arguments must be strings.""" 246 247 SUPPORTS_USER_DEFINED_TYPES = True 248 """Whether user-defined data types are supported.""" 249 250 SUPPORTS_SEMI_ANTI_JOIN = True 251 """Whether `SEMI` or `ANTI` joins are supported.""" 252 253 SUPPORTS_COLUMN_JOIN_MARKS = False 254 """Whether the old-style outer join (+) syntax is supported.""" 255 256 COPY_PARAMS_ARE_CSV = True 257 """Separator of COPY statement parameters.""" 258 259 NORMALIZE_FUNCTIONS: bool | str = "upper" 260 """ 261 Determines how function names are going to be normalized. 262 Possible values: 263 "upper" or True: Convert names to uppercase. 264 "lower": Convert names to lowercase. 265 False: Disables function name normalization. 266 """ 267 268 LOG_BASE_FIRST: t.Optional[bool] = True 269 """ 270 Whether the base comes first in the `LOG` function. 271 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 272 """ 273 274 NULL_ORDERING = "nulls_are_small" 275 """ 276 Default `NULL` ordering method to use if not explicitly set. 277 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 278 """ 279 280 TYPED_DIVISION = False 281 """ 282 Whether the behavior of `a / b` depends on the types of `a` and `b`. 283 False means `a / b` is always float division. 284 True means `a / b` is integer division if both `a` and `b` are integers. 285 """ 286 287 SAFE_DIVISION = False 288 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 289 290 CONCAT_COALESCE = False 291 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 292 293 HEX_LOWERCASE = False 294 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 295 296 DATE_FORMAT = "'%Y-%m-%d'" 297 DATEINT_FORMAT = "'%Y%m%d'" 298 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 299 300 TIME_MAPPING: t.Dict[str, str] = {} 301 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 302 303 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 304 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 305 FORMAT_MAPPING: t.Dict[str, str] = {} 306 """ 307 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 308 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 
309 """ 310 311 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 312 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 313 314 PSEUDOCOLUMNS: t.Set[str] = set() 315 """ 316 Columns that are auto-generated by the engine corresponding to this dialect. 317 For example, such columns may be excluded from `SELECT *` queries. 318 """ 319 320 PREFER_CTE_ALIAS_COLUMN = False 321 """ 322 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 323 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 324 any projection aliases in the subquery. 325 326 For example, 327 WITH y(c) AS ( 328 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 329 ) SELECT c FROM y; 330 331 will be rewritten as 332 333 WITH y(c) AS ( 334 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 335 ) SELECT c FROM y; 336 """ 337 338 COPY_PARAMS_ARE_CSV = True 339 """ 340 Whether COPY statement parameters are separated by comma or whitespace 341 """ 342 343 FORCE_EARLY_ALIAS_REF_EXPANSION = False 344 """ 345 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 346 347 For example: 348 WITH data AS ( 349 SELECT 350 1 AS id, 351 2 AS my_id 352 ) 353 SELECT 354 id AS my_id 355 FROM 356 data 357 WHERE 358 my_id = 1 359 GROUP BY 360 my_id, 361 HAVING 362 my_id = 1 363 364 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 365 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 366 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 367 - Clickhouse, which will forward the alias across the query i.e it resolves 368 to "WHERE id = 1 GROUP BY id HAVING id = 1" 369 """ 370 371 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 372 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 373 374 SUPPORTS_ORDER_BY_ALL = False 375 """ 376 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 377 """ 378 379 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 380 """ 381 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 382 as the former is of type INT[] vs the latter which is SUPER 383 """ 384 385 SUPPORTS_FIXED_SIZE_ARRAYS = False 386 """ 387 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 388 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 389 be interpreted as a subscript/index operator. 390 """ 391 392 STRICT_JSON_PATH_SYNTAX = True 393 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 394 395 ON_CONDITION_EMPTY_BEFORE_ERROR = True 396 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 397 398 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 399 """Whether ArrayAgg needs to filter NULL values.""" 400 401 REGEXP_EXTRACT_DEFAULT_GROUP = 0 402 """The default value for the capturing group.""" 403 404 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 405 exp.Except: True, 406 exp.Intersect: True, 407 exp.Union: True, 408 } 409 """ 410 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 411 must be explicitly specified. 
412 """ 413 414 CREATABLE_KIND_MAPPING: dict[str, str] = {} 415 """ 416 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 417 equivalent of CREATE SCHEMA is CREATE DATABASE. 418 """ 419 420 # --- Autofilled --- 421 422 tokenizer_class = Tokenizer 423 jsonpath_tokenizer_class = JSONPathTokenizer 424 parser_class = Parser 425 generator_class = Generator 426 427 # A trie of the time_mapping keys 428 TIME_TRIE: t.Dict = {} 429 FORMAT_TRIE: t.Dict = {} 430 431 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 432 INVERSE_TIME_TRIE: t.Dict = {} 433 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 434 INVERSE_FORMAT_TRIE: t.Dict = {} 435 436 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 437 438 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 439 440 # Delimiters for string literals and identifiers 441 QUOTE_START = "'" 442 QUOTE_END = "'" 443 IDENTIFIER_START = '"' 444 IDENTIFIER_END = '"' 445 446 # Delimiters for bit, hex, byte and unicode literals 447 BIT_START: t.Optional[str] = None 448 BIT_END: t.Optional[str] = None 449 HEX_START: t.Optional[str] = None 450 HEX_END: t.Optional[str] = None 451 BYTE_START: t.Optional[str] = None 452 BYTE_END: t.Optional[str] = None 453 UNICODE_START: t.Optional[str] = None 454 UNICODE_END: t.Optional[str] = None 455 456 DATE_PART_MAPPING = { 457 "Y": "YEAR", 458 "YY": "YEAR", 459 "YYY": "YEAR", 460 "YYYY": "YEAR", 461 "YR": "YEAR", 462 "YEARS": "YEAR", 463 "YRS": "YEAR", 464 "MM": "MONTH", 465 "MON": "MONTH", 466 "MONS": "MONTH", 467 "MONTHS": "MONTH", 468 "D": "DAY", 469 "DD": "DAY", 470 "DAYS": "DAY", 471 "DAYOFMONTH": "DAY", 472 "DAY OF WEEK": "DAYOFWEEK", 473 "WEEKDAY": "DAYOFWEEK", 474 "DOW": "DAYOFWEEK", 475 "DW": "DAYOFWEEK", 476 "WEEKDAY_ISO": "DAYOFWEEKISO", 477 "DOW_ISO": "DAYOFWEEKISO", 478 "DW_ISO": "DAYOFWEEKISO", 479 "DAY OF YEAR": "DAYOFYEAR", 480 "DOY": "DAYOFYEAR", 481 "DY": "DAYOFYEAR", 482 "W": "WEEK", 483 "WK": "WEEK", 484 "WEEKOFYEAR": "WEEK", 485 "WOY": "WEEK", 486 "WY": "WEEK", 487 "WEEK_ISO": "WEEKISO", 488 "WEEKOFYEARISO": "WEEKISO", 489 "WEEKOFYEAR_ISO": "WEEKISO", 490 "Q": "QUARTER", 491 "QTR": "QUARTER", 492 "QTRS": "QUARTER", 493 "QUARTERS": "QUARTER", 494 "H": "HOUR", 495 "HH": "HOUR", 496 "HR": "HOUR", 497 "HOURS": "HOUR", 498 "HRS": "HOUR", 499 "M": "MINUTE", 500 "MI": "MINUTE", 501 "MIN": "MINUTE", 502 "MINUTES": "MINUTE", 503 "MINS": "MINUTE", 504 "S": "SECOND", 505 "SEC": "SECOND", 506 "SECONDS": "SECOND", 507 "SECS": "SECOND", 508 "MS": "MILLISECOND", 509 "MSEC": "MILLISECOND", 510 "MSECS": "MILLISECOND", 511 "MSECOND": "MILLISECOND", 512 "MSECONDS": "MILLISECOND", 513 "MILLISEC": "MILLISECOND", 514 "MILLISECS": "MILLISECOND", 515 "MILLISECON": "MILLISECOND", 516 "MILLISECONDS": "MILLISECOND", 517 "US": "MICROSECOND", 518 "USEC": "MICROSECOND", 519 "USECS": "MICROSECOND", 520 "MICROSEC": "MICROSECOND", 521 "MICROSECS": "MICROSECOND", 522 "USECOND": "MICROSECOND", 523 "USECONDS": "MICROSECOND", 524 "MICROSECONDS": "MICROSECOND", 525 "NS": "NANOSECOND", 526 "NSEC": "NANOSECOND", 527 "NANOSEC": "NANOSECOND", 528 "NSECOND": "NANOSECOND", 529 "NSECONDS": "NANOSECOND", 530 "NANOSECS": "NANOSECOND", 531 "EPOCH_SECOND": "EPOCH", 532 "EPOCH_SECONDS": "EPOCH", 533 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 534 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 535 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 536 "TZH": "TIMEZONE_HOUR", 537 "TZM": "TIMEZONE_MINUTE", 538 "DEC": "DECADE", 539 "DECS": "DECADE", 540 "DECADES": "DECADE", 541 "MIL": "MILLENIUM", 542 "MILS": "MILLENIUM", 543 "MILLENIA": 
"MILLENIUM", 544 "C": "CENTURY", 545 "CENT": "CENTURY", 546 "CENTS": "CENTURY", 547 "CENTURIES": "CENTURY", 548 } 549 550 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 551 exp.DataType.Type.BIGINT: { 552 exp.ApproxDistinct, 553 exp.ArraySize, 554 exp.Length, 555 }, 556 exp.DataType.Type.BOOLEAN: { 557 exp.Between, 558 exp.Boolean, 559 exp.In, 560 exp.RegexpLike, 561 }, 562 exp.DataType.Type.DATE: { 563 exp.CurrentDate, 564 exp.Date, 565 exp.DateFromParts, 566 exp.DateStrToDate, 567 exp.DiToDate, 568 exp.StrToDate, 569 exp.TimeStrToDate, 570 exp.TsOrDsToDate, 571 }, 572 exp.DataType.Type.DATETIME: { 573 exp.CurrentDatetime, 574 exp.Datetime, 575 exp.DatetimeAdd, 576 exp.DatetimeSub, 577 }, 578 exp.DataType.Type.DOUBLE: { 579 exp.ApproxQuantile, 580 exp.Avg, 581 exp.Exp, 582 exp.Ln, 583 exp.Log, 584 exp.Pow, 585 exp.Quantile, 586 exp.Round, 587 exp.SafeDivide, 588 exp.Sqrt, 589 exp.Stddev, 590 exp.StddevPop, 591 exp.StddevSamp, 592 exp.ToDouble, 593 exp.Variance, 594 exp.VariancePop, 595 }, 596 exp.DataType.Type.INT: { 597 exp.Ceil, 598 exp.DatetimeDiff, 599 exp.DateDiff, 600 exp.TimestampDiff, 601 exp.TimeDiff, 602 exp.DateToDi, 603 exp.Levenshtein, 604 exp.Sign, 605 exp.StrPosition, 606 exp.TsOrDiToDi, 607 }, 608 exp.DataType.Type.JSON: { 609 exp.ParseJSON, 610 }, 611 exp.DataType.Type.TIME: { 612 exp.Time, 613 }, 614 exp.DataType.Type.TIMESTAMP: { 615 exp.CurrentTime, 616 exp.CurrentTimestamp, 617 exp.StrToTime, 618 exp.TimeAdd, 619 exp.TimeStrToTime, 620 exp.TimeSub, 621 exp.TimestampAdd, 622 exp.TimestampSub, 623 exp.UnixToTime, 624 }, 625 exp.DataType.Type.TINYINT: { 626 exp.Day, 627 exp.Month, 628 exp.Week, 629 exp.Year, 630 exp.Quarter, 631 }, 632 exp.DataType.Type.VARCHAR: { 633 exp.ArrayConcat, 634 exp.Concat, 635 exp.ConcatWs, 636 exp.DateToDateStr, 637 exp.GroupConcat, 638 exp.Initcap, 639 exp.Lower, 640 exp.Substring, 641 exp.String, 642 exp.TimeToStr, 643 exp.TimeToTimeStr, 644 exp.Trim, 645 exp.TsOrDsToDateStr, 646 exp.UnixToStr, 647 exp.UnixToTimeStr, 648 exp.Upper, 649 }, 650 } 651 652 ANNOTATORS: AnnotatorsType = { 653 **{ 654 expr_type: lambda self, e: self._annotate_unary(e) 655 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 656 }, 657 **{ 658 expr_type: lambda self, e: self._annotate_binary(e) 659 for expr_type in subclasses(exp.__name__, exp.Binary) 660 }, 661 **{ 662 expr_type: _annotate_with_type_lambda(data_type) 663 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 664 for expr_type in expressions 665 }, 666 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 667 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 668 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 669 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 670 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 671 exp.Bracket: lambda self, e: self._annotate_bracket(e), 672 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 673 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 674 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 675 exp.Count: lambda self, e: self._annotate_with_type( 676 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 677 ), 678 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 679 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 680 exp.DateSub: lambda self, e: 
self._annotate_timeunit(e), 681 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 682 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 683 exp.Div: lambda self, e: self._annotate_div(e), 684 exp.Dot: lambda self, e: self._annotate_dot(e), 685 exp.Explode: lambda self, e: self._annotate_explode(e), 686 exp.Extract: lambda self, e: self._annotate_extract(e), 687 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 688 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 689 e, exp.DataType.build("ARRAY<DATE>") 690 ), 691 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 692 e, exp.DataType.build("ARRAY<TIMESTAMP>") 693 ), 694 exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 695 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 696 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 697 exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 698 exp.Literal: lambda self, e: self._annotate_literal(e), 699 exp.Map: lambda self, e: self._annotate_map(e), 700 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 701 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 702 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 703 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 704 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 705 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 706 exp.Struct: lambda self, e: self._annotate_struct(e), 707 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 708 exp.Timestamp: lambda self, e: self._annotate_with_type( 709 e, 710 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 711 ), 712 exp.ToMap: lambda self, e: self._annotate_to_map(e), 713 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 714 exp.Unnest: lambda self, e: self._annotate_unnest(e), 715 exp.VarMap: lambda self, e: self._annotate_map(e), 716 } 717 718 @classmethod 719 def get_or_raise(cls, dialect: DialectType) -> Dialect: 720 """ 721 Look up a dialect in the global dialect registry and return it if it exists. 722 723 Args: 724 dialect: The target dialect. If this is a string, it can be optionally followed by 725 additional key-value pairs that are separated by commas and are used to specify 726 dialect settings, such as whether the dialect's identifiers are case-sensitive. 727 728 Example: 729 >>> dialect = dialect_class = get_or_raise("duckdb") 730 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 731 732 Returns: 733 The corresponding Dialect instance. 
734 """ 735 736 if not dialect: 737 return cls() 738 if isinstance(dialect, _Dialect): 739 return dialect() 740 if isinstance(dialect, Dialect): 741 return dialect 742 if isinstance(dialect, str): 743 try: 744 dialect_name, *kv_strings = dialect.split(",") 745 kv_pairs = (kv.split("=") for kv in kv_strings) 746 kwargs = {} 747 for pair in kv_pairs: 748 key = pair[0].strip() 749 value: t.Union[bool | str | None] = None 750 751 if len(pair) == 1: 752 # Default initialize standalone settings to True 753 value = True 754 elif len(pair) == 2: 755 value = pair[1].strip() 756 757 # Coerce the value to boolean if it matches to the truthy/falsy values below 758 value_lower = value.lower() 759 if value_lower in ("true", "1"): 760 value = True 761 elif value_lower in ("false", "0"): 762 value = False 763 764 kwargs[key] = value 765 766 except ValueError: 767 raise ValueError( 768 f"Invalid dialect format: '{dialect}'. " 769 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 770 ) 771 772 result = cls.get(dialect_name.strip()) 773 if not result: 774 from difflib import get_close_matches 775 776 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 777 if similar: 778 similar = f" Did you mean {similar}?" 779 780 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 781 782 return result(**kwargs) 783 784 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 785 786 @classmethod 787 def format_time( 788 cls, expression: t.Optional[str | exp.Expression] 789 ) -> t.Optional[exp.Expression]: 790 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 791 if isinstance(expression, str): 792 return exp.Literal.string( 793 # the time formats are quoted 794 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 795 ) 796 797 if expression and expression.is_string: 798 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 799 800 return expression 801 802 def __init__(self, **kwargs) -> None: 803 normalization_strategy = kwargs.pop("normalization_strategy", None) 804 805 if normalization_strategy is None: 806 self.normalization_strategy = self.NORMALIZATION_STRATEGY 807 else: 808 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 809 810 self.settings = kwargs 811 812 def __eq__(self, other: t.Any) -> bool: 813 # Does not currently take dialect state into account 814 return type(self) == other 815 816 def __hash__(self) -> int: 817 # Does not currently take dialect state into account 818 return hash(type(self)) 819 820 def normalize_identifier(self, expression: E) -> E: 821 """ 822 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 823 824 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 825 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 826 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 827 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 828 829 There are also dialects like Spark, which are case-insensitive even when quotes are 830 present, and dialects like MySQL, whose resolution rules match those employed by the 831 underlying operating system, for example they may always be case-sensitive in Linux. 
832 833 Finally, the normalization behavior of some engines can even be controlled through flags, 834 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 835 836 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 837 that it can analyze queries in the optimizer and successfully capture their semantics. 838 """ 839 if ( 840 isinstance(expression, exp.Identifier) 841 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 842 and ( 843 not expression.quoted 844 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 845 ) 846 ): 847 expression.set( 848 "this", 849 ( 850 expression.this.upper() 851 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 852 else expression.this.lower() 853 ), 854 ) 855 856 return expression 857 858 def case_sensitive(self, text: str) -> bool: 859 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 860 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 861 return False 862 863 unsafe = ( 864 str.islower 865 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 866 else str.isupper 867 ) 868 return any(unsafe(char) for char in text) 869 870 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 871 """Checks if text can be identified given an identify option. 872 873 Args: 874 text: The text to check. 875 identify: 876 `"always"` or `True`: Always returns `True`. 877 `"safe"`: Only returns `True` if the identifier is case-insensitive. 878 879 Returns: 880 Whether the given text can be identified. 881 """ 882 if identify is True or identify == "always": 883 return True 884 885 if identify == "safe": 886 return not self.case_sensitive(text) 887 888 return False 889 890 def quote_identifier(self, expression: E, identify: bool = True) -> E: 891 """ 892 Adds quotes to a given identifier. 893 894 Args: 895 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 896 identify: If set to `False`, the quotes will only be added if the identifier is deemed 897 "unsafe", with respect to its characters and this dialect's normalization strategy. 898 """ 899 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 900 name = expression.this 901 expression.set( 902 "quoted", 903 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 904 ) 905 906 return expression 907 908 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 909 if isinstance(path, exp.Literal): 910 path_text = path.name 911 if path.is_number: 912 path_text = f"[{path_text}]" 913 try: 914 return parse_json_path(path_text, self) 915 except ParseError as e: 916 if self.STRICT_JSON_PATH_SYNTAX: 917 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 918 919 return path 920 921 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 922 return self.parser(**opts).parse(self.tokenize(sql), sql) 923 924 def parse_into( 925 self, expression_type: exp.IntoType, sql: str, **opts 926 ) -> t.List[t.Optional[exp.Expression]]: 927 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 928 929 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 930 return self.generator(**opts).generate(expression, copy=copy) 931 932 def transpile(self, sql: str, **opts) -> t.List[str]: 933 return [ 934 self.generate(expression, copy=False, **opts) if expression else "" 935 for expression in self.parse(sql) 936 ] 937 938 def tokenize(self, sql: str) -> t.List[Token]: 939 return self.tokenizer.tokenize(sql) 940 941 @property 942 def tokenizer(self) -> Tokenizer: 943 return self.tokenizer_class(dialect=self) 944 945 @property 946 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 947 return self.jsonpath_tokenizer_class(dialect=self) 948 949 def parser(self, **opts) -> Parser: 950 return self.parser_class(dialect=self, **opts) 951 952 def generator(self, **opts) -> Generator: 953 return self.generator_class(dialect=self, **opts)
802 def __init__(self, **kwargs) -> None: 803 normalization_strategy = kwargs.pop("normalization_strategy", None) 804 805 if normalization_strategy is None: 806 self.normalization_strategy = self.NORMALIZATION_STRATEGY 807 else: 808 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 809 810 self.settings = kwargs
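As a minimal usage sketch of the constructor above (the dialect and setting are arbitrary illustrations, not defaults): the normalization strategy may be passed as a `NormalizationStrategy` member or its case-insensitive string name, and any other keyword arguments are kept in `settings`.

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    # Look up the registered class by name and instantiate it with a setting.
    snowflake = Dialect["snowflake"](normalization_strategy="case_sensitive")
    assert snowflake.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
    assert snowflake.settings == {}  # extra kwargs would be collected here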
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the `LOG` function.
Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`).
Default `NULL` ordering method to use if not explicitly set.
Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.
Whether the behavior of `a / b` depends on the types of `a` and `b`.
False means `a / b` is always float division.
True means `a / b` is integer division if both `a` and `b` are integers.
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python `strftime` formats.
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
Mapping of an escaped sequence (`\n`) to its unescaped version (a literal newline).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from `SELECT *` queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
    WITH data AS (
        SELECT
            1 AS id,
            2 AS my_id
    )
    SELECT
        id AS my_id
    FROM
        data
    WHERE
        my_id = 1
    GROUP BY
        my_id
    HAVING
        my_id = 1

In most dialects, "my_id" would refer to "data.my_id" across the query, except:
- BigQuery, which will forward the alias to the GROUP BY + HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- Clickhouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3), as the former is of type INT[] whereas the latter is of type SUPER.
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator.
Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.
Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).
Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` must be explicitly specified.
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
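A hypothetical sketch of what such a mapping looks like in a dialect subclass (the subclass name and mapping contents are illustrative; any real dialect may differ):

    from sqlglot.dialects.dialect import Dialect

    class MyClickhouseLike(Dialect):  # made-up dialect for illustration only
        # Other dialects say CREATE SCHEMA where this one says CREATE DATABASE.
        CREATABLE_KIND_MAPPING = {"SCHEMA": "DATABASE"}

The inverse mapping (`INVERSE_CREATABLE_KIND_MAPPING`) is autofilled, as noted below.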
718 @classmethod 719 def get_or_raise(cls, dialect: DialectType) -> Dialect: 720 """ 721 Look up a dialect in the global dialect registry and return it if it exists. 722 723 Args: 724 dialect: The target dialect. If this is a string, it can be optionally followed by 725 additional key-value pairs that are separated by commas and are used to specify 726 dialect settings, such as whether the dialect's identifiers are case-sensitive. 727 728 Example: 729 >>> dialect = dialect_class = get_or_raise("duckdb") 730 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 731 732 Returns: 733 The corresponding Dialect instance. 734 """ 735 736 if not dialect: 737 return cls() 738 if isinstance(dialect, _Dialect): 739 return dialect() 740 if isinstance(dialect, Dialect): 741 return dialect 742 if isinstance(dialect, str): 743 try: 744 dialect_name, *kv_strings = dialect.split(",") 745 kv_pairs = (kv.split("=") for kv in kv_strings) 746 kwargs = {} 747 for pair in kv_pairs: 748 key = pair[0].strip() 749 value: t.Union[bool | str | None] = None 750 751 if len(pair) == 1: 752 # Default initialize standalone settings to True 753 value = True 754 elif len(pair) == 2: 755 value = pair[1].strip() 756 757 # Coerce the value to boolean if it matches to the truthy/falsy values below 758 value_lower = value.lower() 759 if value_lower in ("true", "1"): 760 value = True 761 elif value_lower in ("false", "0"): 762 value = False 763 764 kwargs[key] = value 765 766 except ValueError: 767 raise ValueError( 768 f"Invalid dialect format: '{dialect}'. " 769 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 770 ) 771 772 result = cls.get(dialect_name.strip()) 773 if not result: 774 from difflib import get_close_matches 775 776 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 777 if similar: 778 similar = f" Did you mean {similar}?" 779 780 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 781 782 return result(**kwargs) 783 784 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
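A usage sketch based on the signature above (the dialect names are examples, and "my_setting" is a made-up key used only to show how extra settings are collected):

    from sqlglot.dialects.dialect import Dialect

    duckdb = Dialect.get_or_raise("duckdb")  # plain lookup by name

    # Extra comma-separated settings: bare keys default to True, and
    # "true"/"false"/"1"/"0" values are coerced to booleans.
    presto = Dialect.get_or_raise("presto, my_setting = 1")
    assert presto.settings == {"my_setting": True}

Unknown names raise a ValueError, possibly with a "Did you mean ...?" hint.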
786 @classmethod 787 def format_time( 788 cls, expression: t.Optional[str | exp.Expression] 789 ) -> t.Optional[exp.Expression]: 790 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 791 if isinstance(expression, str): 792 return exp.Literal.string( 793 # the time formats are quoted 794 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 795 ) 796 797 if expression and expression.is_string: 798 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 799 800 return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
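For instance, a sketch of converting a dialect-native format (the specific dialect and format string are arbitrary, and no particular output is asserted since it depends on that dialect's `TIME_MAPPING`):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    hive = Dialect.get_or_raise("hive")
    # Plain strings are assumed to be quoted, so the surrounding quotes are
    # stripped before the TIME_MAPPING/TIME_TRIE lookup; the result is a string
    # literal holding the equivalent Python strftime format.
    fmt = hive.format_time("'yyyy-MM-dd'")
    assert isinstance(fmt, exp.Literal) and fmt.is_string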
820 def normalize_identifier(self, expression: E) -> E: 821 """ 822 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 823 824 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 825 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 826 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 827 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 828 829 There are also dialects like Spark, which are case-insensitive even when quotes are 830 present, and dialects like MySQL, whose resolution rules match those employed by the 831 underlying operating system, for example they may always be case-sensitive in Linux. 832 833 Finally, the normalization behavior of some engines can even be controlled through flags, 834 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 835 836 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 837 that it can analyze queries in the optimizer and successfully capture their semantics. 838 """ 839 if ( 840 isinstance(expression, exp.Identifier) 841 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 842 and ( 843 not expression.quoted 844 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 845 ) 846 ): 847 expression.set( 848 "this", 849 ( 850 expression.this.upper() 851 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 852 else expression.this.lower() 853 ), 854 ) 855 856 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
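The Postgres/Snowflake contrast described above, as a small sketch:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO")  # unquoted identifier

    # Postgres lowercases unquoted identifiers...
    assert Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name == "foo"
    # ...while Snowflake uppercases them.
    assert Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name == "FOO"

    # Quoted identifiers are left untouched (unless the strategy is CASE_INSENSITIVE).
    quoted = exp.to_identifier("FoO", quoted=True)
    assert Dialect.get_or_raise("postgres").normalize_identifier(quoted).name == "FoO"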
858 def case_sensitive(self, text: str) -> bool: 859 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 860 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 861 return False 862 863 unsafe = ( 864 str.islower 865 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 866 else str.isupper 867 ) 868 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
870 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 871 """Checks if text can be identified given an identify option. 872 873 Args: 874 text: The text to check. 875 identify: 876 `"always"` or `True`: Always returns `True`. 877 `"safe"`: Only returns `True` if the identifier is case-insensitive. 878 879 Returns: 880 Whether the given text can be identified. 881 """ 882 if identify is True or identify == "always": 883 return True 884 885 if identify == "safe": 886 return not self.case_sensitive(text) 887 888 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
  `"always"` or `True`: Always returns `True`.
  `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
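A quick sketch of the interplay with `case_sensitive` (Postgres lowercases unquoted identifiers by default, so uppercase characters count as case-sensitive):

    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    assert pg.can_identify("foo", "safe")       # no case-sensitive characters
    assert not pg.can_identify("Foo", "safe")   # the uppercase "F" would be lost on normalization
    assert pg.can_identify("Foo", "always")     # always identifiable, regardless of case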
890 def quote_identifier(self, expression: E, identify: bool = True) -> E: 891 """ 892 Adds quotes to a given identifier. 893 894 Args: 895 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 896 identify: If set to `False`, the quotes will only be added if the identifier is deemed 897 "unsafe", with respect to its characters and this dialect's normalization strategy. 898 """ 899 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 900 name = expression.this 901 expression.set( 902 "quoted", 903 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 904 ) 905 906 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
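A sketch with `identify=False`, where quoting only kicks in for "unsafe" names (the identifier names are arbitrary):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    d = Dialect.get_or_raise("duckdb")

    safe = d.quote_identifier(exp.to_identifier("my_col"), identify=False)
    unsafe = d.quote_identifier(exp.to_identifier("my col"), identify=False)

    assert not safe.args.get("quoted")  # plain lowercase name stays unquoted
    assert unsafe.args.get("quoted")    # the space fails SAFE_IDENTIFIER_RE, so it gets quoted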
908 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 909 if isinstance(path, exp.Literal): 910 path_text = path.name 911 if path.is_number: 912 path_text = f"[{path_text}]" 913 try: 914 return parse_json_path(path_text, self) 915 except ParseError as e: 916 if self.STRICT_JSON_PATH_SYNTAX: 917 logger.warning(f"Invalid JSON path syntax. {str(e)}") 918 919 return path
968def if_sql( 969 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 970) -> t.Callable[[Generator, exp.If], str]: 971 def _if_sql(self: Generator, expression: exp.If) -> str: 972 return self.func( 973 name, 974 expression.this, 975 expression.args.get("true"), 976 expression.args.get("false") or false_value, 977 ) 978 979 return _if_sql
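Factories like this one are typically wired into a dialect's generator through its `TRANSFORMS` mapping. A hypothetical sketch (the dialect below is made up; its only purpose is to show the wiring pattern):

    from sqlglot import exp, generator
    from sqlglot.dialects.dialect import Dialect, if_sql

    class MyDialect(Dialect):  # made-up dialect whose conditional function is spelled IIF
        class Generator(generator.Generator):
            TRANSFORMS = {
                **generator.Generator.TRANSFORMS,
                exp.If: if_sql(name="IIF"),
            }

With this in place, generating SQL for "mydialect" should render `exp.If` nodes as `IIF(...)` calls; the same pattern applies to the other `*_sql` helper factories in this module.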
982def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 983 this = expression.this 984 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 985 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 986 987 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1057def str_position_sql( 1058 self: Generator, 1059 expression: exp.StrPosition, 1060 generate_instance: bool = False, 1061 str_position_func_name: str = "STRPOS", 1062) -> str: 1063 this = self.sql(expression, "this") 1064 substr = self.sql(expression, "substr") 1065 position = self.sql(expression, "position") 1066 instance = expression.args.get("instance") if generate_instance else None 1067 position_offset = "" 1068 1069 if position: 1070 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1071 this = self.func("SUBSTR", this, position) 1072 position_offset = f" + {position} - 1" 1073 1074 return self.func(str_position_func_name, this, substr, instance) + position_offset
1083def var_map_sql( 1084 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1085) -> str: 1086 keys = expression.args["keys"] 1087 values = expression.args["values"] 1088 1089 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1090 self.unsupported("Cannot convert array columns into map.") 1091 return self.func(map_func_name, keys, values) 1092 1093 args = [] 1094 for key, value in zip(keys.expressions, values.expressions): 1095 args.append(self.sql(key)) 1096 args.append(self.sql(value)) 1097 1098 return self.func(map_func_name, *args)
1101def build_formatted_time( 1102 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1103) -> t.Callable[[t.List], E]: 1104 """Helper used for time expressions. 1105 1106 Args: 1107 exp_class: the expression class to instantiate. 1108 dialect: target sql dialect. 1109 default: the default format, True being time. 1110 1111 Returns: 1112 A callable that can be used to return the appropriately formatted time expression. 1113 """ 1114 1115 def _builder(args: t.List): 1116 return exp_class( 1117 this=seq_get(args, 0), 1118 format=Dialect[dialect].format_time( 1119 seq_get(args, 1) 1120 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1121 ), 1122 ) 1123 1124 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
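Such builders are normally registered in a dialect parser's `FUNCTIONS` mapping so that the format argument is converted at parse time. A hypothetical sketch (the dialect and the `TO_TIMESTAMP` function name are chosen for illustration only):

    from sqlglot import exp, parser
    from sqlglot.dialects.dialect import Dialect, build_formatted_time

    class MyDialect(Dialect):  # made-up dialect
        class Parser(parser.Parser):
            FUNCTIONS = {
                **parser.Parser.FUNCTIONS,
                # Parse TO_TIMESTAMP(<value>, <fmt>) into exp.StrToTime, converting the
                # dialect-native format via Dialect["mydialect"].format_time at parse time.
                "TO_TIMESTAMP": build_formatted_time(exp.StrToTime, "mydialect"),
            }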
1127def time_format( 1128 dialect: DialectType = None, 1129) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1130 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1131 """ 1132 Returns the time format for a given expression, unless it's equivalent 1133 to the default time format of the dialect of interest. 1134 """ 1135 time_format = self.format_time(expression) 1136 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1137 1138 return _time_format
1141def build_date_delta( 1142 exp_class: t.Type[E], 1143 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1144 default_unit: t.Optional[str] = "DAY", 1145) -> t.Callable[[t.List], E]: 1146 def _builder(args: t.List) -> E: 1147 unit_based = len(args) == 3 1148 this = args[2] if unit_based else seq_get(args, 0) 1149 unit = None 1150 if unit_based or default_unit: 1151 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1152 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1153 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1154 1155 return _builder
1158def build_date_delta_with_interval( 1159 expression_class: t.Type[E], 1160) -> t.Callable[[t.List], t.Optional[E]]: 1161 def _builder(args: t.List) -> t.Optional[E]: 1162 if len(args) < 2: 1163 return None 1164 1165 interval = args[1] 1166 1167 if not isinstance(interval, exp.Interval): 1168 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1169 1170 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1171 1172 return _builder
1175def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1176 unit = seq_get(args, 0) 1177 this = seq_get(args, 1) 1178 1179 if isinstance(this, exp.Cast) and this.is_type("date"): 1180 return exp.DateTrunc(unit=unit, this=this) 1181 return exp.TimestampTrunc(this=this, unit=unit)
1184def date_add_interval_sql( 1185 data_type: str, kind: str 1186) -> t.Callable[[Generator, exp.Expression], str]: 1187 def func(self: Generator, expression: exp.Expression) -> str: 1188 this = self.sql(expression, "this") 1189 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1190 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1191 1192 return func
1195def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1196 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1197 args = [unit_to_str(expression), expression.this] 1198 if zone: 1199 args.append(expression.args.get("zone")) 1200 return self.func("DATE_TRUNC", *args) 1201 1202 return _timestamptrunc_sql
1205def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1206 zone = expression.args.get("zone") 1207 if not zone: 1208 from sqlglot.optimizer.annotate_types import annotate_types 1209 1210 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1211 return self.sql(exp.cast(expression.this, target_type)) 1212 if zone.name.lower() in TIMEZONES: 1213 return self.sql( 1214 exp.AtTimeZone( 1215 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1216 zone=zone, 1217 ) 1218 ) 1219 return self.func("TIMESTAMP", expression.this, zone)
1222def no_time_sql(self: Generator, expression: exp.Time) -> str: 1223 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1224 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1225 expr = exp.cast( 1226 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1227 ) 1228 return self.sql(expr)
1231def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1232 this = expression.this 1233 expr = expression.expression 1234 1235 if expr.name.lower() in TIMEZONES: 1236 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1237 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1238 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1239 return self.sql(this) 1240 1241 this = exp.cast(this, exp.DataType.Type.DATE) 1242 expr = exp.cast(expr, exp.DataType.Type.TIME) 1243 1244 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1276def timestrtotime_sql( 1277 self: Generator, 1278 expression: exp.TimeStrToTime, 1279 include_precision: bool = False, 1280) -> str: 1281 datatype = exp.DataType.build( 1282 exp.DataType.Type.TIMESTAMPTZ 1283 if expression.args.get("zone") 1284 else exp.DataType.Type.TIMESTAMP 1285 ) 1286 1287 if isinstance(expression.this, exp.Literal) and include_precision: 1288 precision = subsecond_precision(expression.this.name) 1289 if precision > 0: 1290 datatype = exp.DataType.build( 1291 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1292 ) 1293 1294 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
1302def encode_decode_sql( 1303 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1304) -> str: 1305 charset = expression.args.get("charset") 1306 if charset and charset.name.lower() != "utf-8": 1307 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1308 1309 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1322def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1323 cond = expression.this 1324 1325 if isinstance(expression.this, exp.Distinct): 1326 cond = expression.this.expressions[0] 1327 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1328 1329 return self.func("sum", exp.func("if", cond, 1, 0))
1332def trim_sql(self: Generator, expression: exp.Trim) -> str: 1333 target = self.sql(expression, "this") 1334 trim_type = self.sql(expression, "position") 1335 remove_chars = self.sql(expression, "expression") 1336 collation = self.sql(expression, "collation") 1337 1338 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1339 if not remove_chars: 1340 return self.trim_sql(expression) 1341 1342 trim_type = f"{trim_type} " if trim_type else "" 1343 remove_chars = f"{remove_chars} " if remove_chars else "" 1344 from_part = "FROM " if trim_type or remove_chars else "" 1345 collation = f" COLLATE {collation}" if collation else "" 1346 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1367@unsupported_args("position", "occurrence", "parameters") 1368def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1369 group = expression.args.get("group") 1370 1371 # Do not render group if it's the default value for this dialect 1372 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1373 group = None 1374 1375 return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
1385def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1386 names = [] 1387 for agg in aggregations: 1388 if isinstance(agg, exp.Alias): 1389 names.append(agg.alias) 1390 else: 1391 """ 1392 This case corresponds to aggregations without aliases being used as suffixes 1393 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1394 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1395 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1396 """ 1397 agg_all_unquoted = agg.transform( 1398 lambda node: ( 1399 exp.Identifier(this=node.name, quoted=False) 1400 if isinstance(node, exp.Identifier) 1401 else node 1402 ) 1403 ) 1404 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1405 1406 return names
1446def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1447 @unsupported_args("count") 1448 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1449 return self.func(name, expression.this, expression.expression) 1450 1451 return _arg_max_or_min_sql
1454def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1455 this = expression.this.copy() 1456 1457 return_type = expression.return_type 1458 if return_type.is_type(exp.DataType.Type.DATE): 1459 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1460 # can truncate timestamp strings, because some dialects can't cast them to DATE 1461 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1462 1463 expression.this.replace(exp.cast(this, return_type)) 1464 return expression
1467def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1468 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1469 if cast and isinstance(expression, exp.TsOrDsAdd): 1470 expression = ts_or_ds_add_cast(expression) 1471 1472 return self.func( 1473 name, 1474 unit_to_var(expression), 1475 expression.expression, 1476 expression.this, 1477 ) 1478 1479 return _delta_sql
1482def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1483 unit = expression.args.get("unit") 1484 1485 if isinstance(unit, exp.Placeholder): 1486 return unit 1487 if unit: 1488 return exp.Literal.string(unit.name) 1489 return exp.Literal.string(default) if default else None
1519def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1520 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1521 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1522 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1523 1524 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1527def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1528 """Remove table refs from columns in when statements.""" 1529 alias = expression.this.args.get("alias") 1530 1531 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1532 return self.dialect.normalize_identifier(identifier).name if identifier else None 1533 1534 targets = {normalize(expression.this.this)} 1535 1536 if alias: 1537 targets.add(normalize(alias.this)) 1538 1539 for when in expression.expressions: 1540 # only remove the target names from the THEN clause 1541 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1542 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1543 then = when.args.get("then") 1544 if then: 1545 then.transform( 1546 lambda node: ( 1547 exp.column(node.this) 1548 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1549 else node 1550 ), 1551 copy=False, 1552 ) 1553 1554 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1557def build_json_extract_path( 1558 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1559) -> t.Callable[[t.List], F]: 1560 def _builder(args: t.List) -> F: 1561 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1562 for arg in args[1:]: 1563 if not isinstance(arg, exp.Literal): 1564 # We use the fallback parser because we can't really transpile non-literals safely 1565 return expr_type.from_arg_list(args) 1566 1567 text = arg.name 1568 if is_int(text): 1569 index = int(text) 1570 segments.append( 1571 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1572 ) 1573 else: 1574 segments.append(exp.JSONPathKey(this=text)) 1575 1576 # This is done to avoid failing in the expression validator due to the arg count 1577 del args[2:] 1578 return expr_type( 1579 this=seq_get(args, 0), 1580 expression=exp.JSONPath(expressions=segments), 1581 only_json_types=arrow_req_json_type, 1582 ) 1583 1584 return _builder
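A sketch of what the builder produces for literal path arguments (the column and path pieces are arbitrary examples):

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_json_extract_path

    builder = build_json_extract_path(exp.JSONExtract)
    node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])

    # The literal segments become a structured JSONPath equivalent to $.a[0];
    # non-literal arguments instead fall back to expr_type.from_arg_list.
    assert isinstance(node, exp.JSONExtract)
    assert isinstance(node.expression, exp.JSONPath)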
1587def json_extract_segments( 1588 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1589) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1590 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1591 path = expression.expression 1592 if not isinstance(path, exp.JSONPath): 1593 return rename_func(name)(self, expression) 1594 1595 escape = path.args.get("escape") 1596 1597 segments = [] 1598 for segment in path.expressions: 1599 path = self.sql(segment) 1600 if path: 1601 if isinstance(segment, exp.JSONPathPart) and ( 1602 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1603 ): 1604 if escape: 1605 path = self.escape_str(path) 1606 1607 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1608 1609 segments.append(path) 1610 1611 if op: 1612 return f" {op} ".join([self.sql(expression.this), *segments]) 1613 return self.func(name, expression.this, *segments) 1614 1615 return _json_extract_segments
1625def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1626 cond = expression.expression 1627 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1628 alias = cond.expressions[0] 1629 cond = cond.this 1630 elif isinstance(cond, exp.Predicate): 1631 alias = "_u" 1632 else: 1633 self.unsupported("Unsupported filter condition") 1634 return "" 1635 1636 unnest = exp.Unnest(expressions=[expression.this]) 1637 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1638 return self.sql(exp.Array(expressions=[filtered]))
1650def build_default_decimal_type( 1651 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1652) -> t.Callable[[exp.DataType], exp.DataType]: 1653 def _builder(dtype: exp.DataType) -> exp.DataType: 1654 if dtype.expressions or precision is None: 1655 return dtype 1656 1657 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1658 return exp.DataType.build(f"DECIMAL({params})") 1659 1660 return _builder
1663def build_timestamp_from_parts(args: t.List) -> exp.Func: 1664 if len(args) == 2: 1665 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1666 # so we parse this into Anonymous for now instead of introducing complexity 1667 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1668 1669 return exp.TimestampFromParts.from_arg_list(args)
1676def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1677 start = expression.args.get("start") 1678 end = expression.args.get("end") 1679 step = expression.args.get("step") 1680 1681 if isinstance(start, exp.Cast): 1682 target_type = start.to 1683 elif isinstance(end, exp.Cast): 1684 target_type = end.to 1685 else: 1686 target_type = None 1687 1688 if start and end and target_type and target_type.is_type("date", "timestamp"): 1689 if isinstance(start, exp.Cast) and target_type is start.to: 1690 end = exp.cast(end, target_type) 1691 else: 1692 start = exp.cast(start, target_type) 1693 1694 return self.func("SEQUENCE", start, end, step)
1705def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1706 if isinstance(expression.this, exp.Explode): 1707 return self.sql( 1708 exp.Join( 1709 this=exp.Unnest( 1710 expressions=[expression.this.this], 1711 alias=expression.args.get("alias"), 1712 offset=isinstance(expression.this, exp.Posexplode), 1713 ), 1714 kind="cross", 1715 ) 1716 ) 1717 return self.lateral_sql(expression)