# sqlglot/dialects/dialect.py
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    from sqlglot.optimizer.annotate_types import TypeAnnotator

    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]

logger = logging.getLogger("sqlglot")

# Backslash escape sequences mapped to the characters they denote; merged into a
# dialect's UNESCAPED_SEQUENCES when its tokenizer treats "\" as a string escape.
UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    # Factory so the lambda closes over `data_type` by value (avoids the
    # late-binding-closure pitfall when built inside a loop).
    return lambda self, e: self._annotate_with_type(e, data_type)


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    """Metaclass that registers every Dialect subclass and autofills its
    tokenizer/parser/generator classes and derived lookup tables."""

    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class compares equal to itself, to its registry name, and
        # to instances of itself.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the enum value when the class name matches a known
        # dialect, otherwise under the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        # Inherit the component classes from the first base, subclassing them
        # so each dialect gets its own Tokenizer/Parser/Generator types.
        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # NOTE: the loop variable `t` shadows the `typing` alias inside
            # this generator expression only.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        # Feature toggles keyed off the dialect name (Dialects is a str Enum,
        # so `enum` compares directly against plain strings).
        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Separator of COPY statement parameters."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # NOTE(review): duplicate of COPY_PARAMS_ARE_CSV defined above; this later
    # definition wins. Kept for fidelity — consider removing one of the two.
    COPY_PARAMS_ARE_CSV = True
    """
    Whether COPY statement parameters are separated by comma or whitespace
    """

    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).

    For example:
        WITH data AS (
        SELECT
            1 AS id,
            2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id,
        HAVING
            my_id = 1

    In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:
        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    SUPPORTS_ORDER_BY_ALL = False
    """
    Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    jsonpath_tokenizer_class = JSONPathTokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}
    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
    INVERSE_FORMAT_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    DATE_PART_MAPPING = {
        "Y": "YEAR",
        "YY": "YEAR",
        "YYY": "YEAR",
        "YYYY": "YEAR",
        "YR": "YEAR",
        "YEARS": "YEAR",
        "YRS": "YEAR",
        "MM": "MONTH",
        "MON": "MONTH",
        "MONS": "MONTH",
        "MONTHS": "MONTH",
        "D": "DAY",
        "DD": "DAY",
        "DAYS": "DAY",
        "DAYOFMONTH": "DAY",
        "DAY OF WEEK": "DAYOFWEEK",
        "WEEKDAY": "DAYOFWEEK",
        "DOW": "DAYOFWEEK",
        "DW": "DAYOFWEEK",
        "WEEKDAY_ISO": "DAYOFWEEKISO",
        "DOW_ISO": "DAYOFWEEKISO",
        "DW_ISO": "DAYOFWEEKISO",
        "DAY OF YEAR": "DAYOFYEAR",
        "DOY": "DAYOFYEAR",
        "DY": "DAYOFYEAR",
        "W": "WEEK",
        "WK": "WEEK",
        "WEEKOFYEAR": "WEEK",
        "WOY": "WEEK",
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }

    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.Datetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Div,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Levenshtein,
            exp.Sign,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.JSON: {
            exp.ParseJSON,
        },
        exp.DataType.Type.TIME: {
            exp.Time,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
            exp.Quarter,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    ANNOTATORS: AnnotatorsType = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Extract: lambda self, e: self._annotate_extract(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_strings = dialect.split(",")
                kv_pairs = (kv.split("=") for kv in kv_strings)
                kwargs = {}
                for pair in kv_pairs:
                    key = pair[0].strip()
                    value: t.Union[bool | str | None] = None

                    if len(pair) == 1:
                        # Default initialize standalone settings to True
                        value = True
                    elif len(pair) == 2:
                        value = pair[1].strip()

                        # Coerce the value to boolean if it matches to the truthy/falsy values below
                        value_lower = value.lower()
                        if value_lower in ("true", "1"):
                            value = True
                        elif value_lower in ("false", "0"):
                            value = False

                    kwargs[key] = value

            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                # Suggest the closest registered dialect name, if any.
                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        # Remaining keyword arguments are stored as free-form dialect settings.
        self.settings = kwargs

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        # A character is "unsafe" when it would be changed by normalization.
        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"
            try:
                return parse_json_path(path_text, self)
            except ParseError as e:
                # Best-effort: fall back to the raw path on invalid JSON path syntax.
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        return self.tokenizer_class(dialect=self)

    @property
    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
        return self.jsonpath_tokenizer_class(dialect=self)

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Returns a transform that renders the expression as a call to `name`."""
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Returns a transform that renders `IF`-style conditionals as a call to `name`."""

    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    # Emulate case-insensitive LIKE by lowercasing both operands.
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave keys and values: MAP(k1, v1, k2, v2, ...)
    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    # NOTE(review): the body of this function lies past the visible edge of this
    # chunk; it was restored from the upstream implementation — verify against
    # the full file.
    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (exp.Literal.string(default) if default is True else default or None)
            ),
        )

    return _builder
1049 """ 1050 1051 def _builder(args: t.List): 1052 return exp_class( 1053 this=seq_get(args, 0), 1054 format=Dialect[dialect].format_time( 1055 seq_get(args, 1) 1056 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1057 ), 1058 ) 1059 1060 return _builder 1061 1062 1063def time_format( 1064 dialect: DialectType = None, 1065) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1066 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1067 """ 1068 Returns the time format for a given expression, unless it's equivalent 1069 to the default time format of the dialect of interest. 1070 """ 1071 time_format = self.format_time(expression) 1072 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1073 1074 return _time_format 1075 1076 1077def build_date_delta( 1078 exp_class: t.Type[E], 1079 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1080 default_unit: t.Optional[str] = "DAY", 1081) -> t.Callable[[t.List], E]: 1082 def _builder(args: t.List) -> E: 1083 unit_based = len(args) == 3 1084 this = args[2] if unit_based else seq_get(args, 0) 1085 unit = None 1086 if unit_based or default_unit: 1087 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1088 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1089 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1090 1091 return _builder 1092 1093 1094def build_date_delta_with_interval( 1095 expression_class: t.Type[E], 1096) -> t.Callable[[t.List], t.Optional[E]]: 1097 def _builder(args: t.List) -> t.Optional[E]: 1098 if len(args) < 2: 1099 return None 1100 1101 interval = args[1] 1102 1103 if not isinstance(interval, exp.Interval): 1104 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1105 1106 expression = interval.this 1107 if expression and expression.is_string: 1108 expression = 
exp.Literal.number(expression.this) 1109 1110 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 1111 1112 return _builder 1113 1114 1115def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1116 unit = seq_get(args, 0) 1117 this = seq_get(args, 1) 1118 1119 if isinstance(this, exp.Cast) and this.is_type("date"): 1120 return exp.DateTrunc(unit=unit, this=this) 1121 return exp.TimestampTrunc(this=this, unit=unit) 1122 1123 1124def date_add_interval_sql( 1125 data_type: str, kind: str 1126) -> t.Callable[[Generator, exp.Expression], str]: 1127 def func(self: Generator, expression: exp.Expression) -> str: 1128 this = self.sql(expression, "this") 1129 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1130 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1131 1132 return func 1133 1134 1135def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1136 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1137 args = [unit_to_str(expression), expression.this] 1138 if zone: 1139 args.append(expression.args.get("zone")) 1140 return self.func("DATE_TRUNC", *args) 1141 1142 return _timestamptrunc_sql 1143 1144 1145def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1146 zone = expression.args.get("zone") 1147 if not zone: 1148 from sqlglot.optimizer.annotate_types import annotate_types 1149 1150 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1151 return self.sql(exp.cast(expression.this, target_type)) 1152 if zone.name.lower() in TIMEZONES: 1153 return self.sql( 1154 exp.AtTimeZone( 1155 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1156 zone=zone, 1157 ) 1158 ) 1159 return self.func("TIMESTAMP", expression.this, zone) 1160 1161 1162def no_time_sql(self: Generator, expression: exp.Time) -> str: 1163 # Transpile BQ's TIME(timestamp, zone) to 
CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1164 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1165 expr = exp.cast( 1166 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1167 ) 1168 return self.sql(expr) 1169 1170 1171def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1172 this = expression.this 1173 expr = expression.expression 1174 1175 if expr.name.lower() in TIMEZONES: 1176 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1177 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1178 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1179 return self.sql(this) 1180 1181 this = exp.cast(this, exp.DataType.Type.DATE) 1182 expr = exp.cast(expr, exp.DataType.Type.TIME) 1183 1184 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP)) 1185 1186 1187def locate_to_strposition(args: t.List) -> exp.Expression: 1188 return exp.StrPosition( 1189 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 1190 ) 1191 1192 1193def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 1194 return self.func( 1195 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 1196 ) 1197 1198 1199def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1200 return self.sql( 1201 exp.Substring( 1202 this=expression.this, start=exp.Literal.number(1), length=expression.expression 1203 ) 1204 ) 1205 1206 1207def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1208 return self.sql( 1209 exp.Substring( 1210 this=expression.this, 1211 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 1212 ) 1213 ) 1214 1215 1216def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: 1217 return self.sql(exp.cast(expression.this, 
exp.DataType.Type.TIMESTAMP)) 1218 1219 1220def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 1221 return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) 1222 1223 1224# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 1225def encode_decode_sql( 1226 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1227) -> str: 1228 charset = expression.args.get("charset") 1229 if charset and charset.name.lower() != "utf-8": 1230 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1231 1232 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 1233 1234 1235def min_or_least(self: Generator, expression: exp.Min) -> str: 1236 name = "LEAST" if expression.expressions else "MIN" 1237 return rename_func(name)(self, expression) 1238 1239 1240def max_or_greatest(self: Generator, expression: exp.Max) -> str: 1241 name = "GREATEST" if expression.expressions else "MAX" 1242 return rename_func(name)(self, expression) 1243 1244 1245def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1246 cond = expression.this 1247 1248 if isinstance(expression.this, exp.Distinct): 1249 cond = expression.this.expressions[0] 1250 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1251 1252 return self.func("sum", exp.func("if", cond, 1, 0)) 1253 1254 1255def trim_sql(self: Generator, expression: exp.Trim) -> str: 1256 target = self.sql(expression, "this") 1257 trim_type = self.sql(expression, "position") 1258 remove_chars = self.sql(expression, "expression") 1259 collation = self.sql(expression, "collation") 1260 1261 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1262 if not remove_chars and not collation: 1263 return self.trim_sql(expression) 1264 1265 trim_type = f"{trim_type} " if trim_type else "" 1266 remove_chars = f"{remove_chars} " if remove_chars else "" 1267 from_part = 
"FROM " if trim_type or remove_chars else "" 1268 collation = f" COLLATE {collation}" if collation else "" 1269 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 1270 1271 1272def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 1273 return self.func("STRPTIME", expression.this, self.format_time(expression)) 1274 1275 1276def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 1277 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 1278 1279 1280def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 1281 delim, *rest_args = expression.expressions 1282 return self.sql( 1283 reduce( 1284 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 1285 rest_args, 1286 ) 1287 ) 1288 1289 1290def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1291 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1292 if bad_args: 1293 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1294 1295 return self.func( 1296 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1297 ) 1298 1299 1300def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1301 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1302 if bad_args: 1303 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1304 1305 return self.func( 1306 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1307 ) 1308 1309 1310def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1311 names = [] 1312 for agg in aggregations: 1313 if isinstance(agg, exp.Alias): 1314 names.append(agg.alias) 1315 else: 1316 """ 1317 This case corresponds to aggregations without aliases being used as suffixes 
1318 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1319 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1320 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1321 """ 1322 agg_all_unquoted = agg.transform( 1323 lambda node: ( 1324 exp.Identifier(this=node.name, quoted=False) 1325 if isinstance(node, exp.Identifier) 1326 else node 1327 ) 1328 ) 1329 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1330 1331 return names 1332 1333 1334def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1335 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1336 1337 1338# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1339def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1340 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1341 1342 1343def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1344 return self.func("MAX", expression.this) 1345 1346 1347def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1348 a = self.sql(expression.left) 1349 b = self.sql(expression.right) 1350 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1351 1352 1353def is_parse_json(expression: exp.Expression) -> bool: 1354 return isinstance(expression, exp.ParseJSON) or ( 1355 isinstance(expression, exp.Cast) and expression.is_type("json") 1356 ) 1357 1358 1359def isnull_to_is_null(args: t.List) -> exp.Expression: 1360 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1361 1362 1363def generatedasidentitycolumnconstraint_sql( 1364 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1365) -> str: 1366 start = self.sql(expression, "start") or "1" 1367 increment = self.sql(expression, "increment") or "1" 1368 return f"IDENTITY({start}, {increment})" 1369 1370 1371def arg_max_or_min_no_count(name: str) -> 
t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1372 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1373 if expression.args.get("count"): 1374 self.unsupported(f"Only two arguments are supported in function {name}.") 1375 1376 return self.func(name, expression.this, expression.expression) 1377 1378 return _arg_max_or_min_sql 1379 1380 1381def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1382 this = expression.this.copy() 1383 1384 return_type = expression.return_type 1385 if return_type.is_type(exp.DataType.Type.DATE): 1386 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1387 # can truncate timestamp strings, because some dialects can't cast them to DATE 1388 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1389 1390 expression.this.replace(exp.cast(this, return_type)) 1391 return expression 1392 1393 1394def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1395 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1396 if cast and isinstance(expression, exp.TsOrDsAdd): 1397 expression = ts_or_ds_add_cast(expression) 1398 1399 return self.func( 1400 name, 1401 unit_to_var(expression), 1402 expression.expression, 1403 expression.this, 1404 ) 1405 1406 return _delta_sql 1407 1408 1409def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1410 unit = expression.args.get("unit") 1411 1412 if isinstance(unit, exp.Placeholder): 1413 return unit 1414 if unit: 1415 return exp.Literal.string(unit.name) 1416 return exp.Literal.string(default) if default else None 1417 1418 1419def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1420 unit = expression.args.get("unit") 1421 1422 if isinstance(unit, (exp.Var, exp.Placeholder)): 1423 return unit 1424 return exp.Var(this=default) if default else None 1425 1426 1427@t.overload 1428def 
map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: 1429 pass 1430 1431 1432@t.overload 1433def map_date_part( 1434 part: t.Optional[exp.Expression], dialect: DialectType = Dialect 1435) -> t.Optional[exp.Expression]: 1436 pass 1437 1438 1439def map_date_part(part, dialect: DialectType = Dialect): 1440 mapped = ( 1441 Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None 1442 ) 1443 return exp.var(mapped) if mapped else part 1444 1445 1446def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1447 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1448 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1449 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1450 1451 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1452 1453 1454def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1455 """Remove table refs from columns in when statements.""" 1456 alias = expression.this.args.get("alias") 1457 1458 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1459 return self.dialect.normalize_identifier(identifier).name if identifier else None 1460 1461 targets = {normalize(expression.this.this)} 1462 1463 if alias: 1464 targets.add(normalize(alias.this)) 1465 1466 for when in expression.expressions: 1467 when.transform( 1468 lambda node: ( 1469 exp.column(node.this) 1470 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1471 else node 1472 ), 1473 copy=False, 1474 ) 1475 1476 return self.merge_sql(expression) 1477 1478 1479def build_json_extract_path( 1480 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1481) -> t.Callable[[t.List], F]: 1482 def _builder(args: t.List) -> F: 1483 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1484 for arg in args[1:]: 1485 if not isinstance(arg, exp.Literal): 1486 # We use the 
fallback parser because we can't really transpile non-literals safely 1487 return expr_type.from_arg_list(args) 1488 1489 text = arg.name 1490 if is_int(text): 1491 index = int(text) 1492 segments.append( 1493 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1494 ) 1495 else: 1496 segments.append(exp.JSONPathKey(this=text)) 1497 1498 # This is done to avoid failing in the expression validator due to the arg count 1499 del args[2:] 1500 return expr_type( 1501 this=seq_get(args, 0), 1502 expression=exp.JSONPath(expressions=segments), 1503 only_json_types=arrow_req_json_type, 1504 ) 1505 1506 return _builder 1507 1508 1509def json_extract_segments( 1510 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1511) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1512 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1513 path = expression.expression 1514 if not isinstance(path, exp.JSONPath): 1515 return rename_func(name)(self, expression) 1516 1517 segments = [] 1518 for segment in path.expressions: 1519 path = self.sql(segment) 1520 if path: 1521 if isinstance(segment, exp.JSONPathPart) and ( 1522 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1523 ): 1524 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1525 1526 segments.append(path) 1527 1528 if op: 1529 return f" {op} ".join([self.sql(expression.this), *segments]) 1530 return self.func(name, expression.this, *segments) 1531 1532 return _json_extract_segments 1533 1534 1535def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1536 if isinstance(expression.this, exp.JSONPathWildcard): 1537 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1538 1539 return expression.name 1540 1541 1542def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1543 cond = expression.expression 1544 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 
1545 alias = cond.expressions[0] 1546 cond = cond.this 1547 elif isinstance(cond, exp.Predicate): 1548 alias = "_u" 1549 else: 1550 self.unsupported("Unsupported filter condition") 1551 return "" 1552 1553 unnest = exp.Unnest(expressions=[expression.this]) 1554 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1555 return self.sql(exp.Array(expressions=[filtered])) 1556 1557 1558def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1559 return self.func( 1560 "TO_NUMBER", 1561 expression.this, 1562 expression.args.get("format"), 1563 expression.args.get("nlsparam"), 1564 ) 1565 1566 1567def build_default_decimal_type( 1568 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1569) -> t.Callable[[exp.DataType], exp.DataType]: 1570 def _builder(dtype: exp.DataType) -> exp.DataType: 1571 if dtype.expressions or precision is None: 1572 return dtype 1573 1574 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1575 return exp.DataType.build(f"DECIMAL({params})") 1576 1577 return _builder 1578 1579 1580def build_timestamp_from_parts(args: t.List) -> exp.Func: 1581 if len(args) == 2: 1582 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1583 # so we parse this into Anonymous for now instead of introducing complexity 1584 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1585 1586 return exp.TimestampFromParts.from_arg_list(args) 1587 1588 1589def sha256_sql(self: Generator, expression: exp.SHA2) -> str: 1590 return self.func(f"SHA{expression.text('length') or '256'}", expression.this) 1591 1592 1593def sequence_sql(self: Generator, expression: exp.GenerateSeries): 1594 start = expression.args["start"] 1595 end = expression.args["end"] 1596 step = expression.args.get("step") 1597 1598 if isinstance(start, exp.Cast): 1599 target_type = start.to 1600 elif isinstance(end, exp.Cast): 1601 target_type = end.to 1602 else: 1603 target_type = None 
1604 1605 if target_type and target_type.is_type("timestamp"): 1606 if target_type is start.to: 1607 end = exp.cast(end, target_type) 1608 else: 1609 start = exp.cast(start, target_type) 1610 1611 return self.func("SEQUENCE", start, end, step)
49class Dialects(str, Enum): 50 """Dialects supported by SQLGLot.""" 51 52 DIALECT = "" 53 54 ATHENA = "athena" 55 BIGQUERY = "bigquery" 56 CLICKHOUSE = "clickhouse" 57 DATABRICKS = "databricks" 58 DORIS = "doris" 59 DRILL = "drill" 60 DUCKDB = "duckdb" 61 HIVE = "hive" 62 MATERIALIZE = "materialize" 63 MYSQL = "mysql" 64 ORACLE = "oracle" 65 POSTGRES = "postgres" 66 PRESTO = "presto" 67 PRQL = "prql" 68 REDSHIFT = "redshift" 69 RISINGWAVE = "risingwave" 70 SNOWFLAKE = "snowflake" 71 SPARK = "spark" 72 SPARK2 = "spark2" 73 SQLITE = "sqlite" 74 STARROCKS = "starrocks" 75 TABLEAU = "tableau" 76 TERADATA = "teradata" 77 TRINO = "trino" 78 TSQL = "tsql"
Dialects supported by SQLGlot.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
81class NormalizationStrategy(str, AutoName): 82 """Specifies the strategy according to which identifiers should be normalized.""" 83 84 LOWERCASE = auto() 85 """Unquoted identifiers are lowercased.""" 86 87 UPPERCASE = auto() 88 """Unquoted identifiers are uppercased.""" 89 90 CASE_SENSITIVE = auto() 91 """Always case-sensitive, regardless of quotes.""" 92 93 CASE_INSENSITIVE = auto() 94 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
215class Dialect(metaclass=_Dialect): 216 INDEX_OFFSET = 0 217 """The base index offset for arrays.""" 218 219 WEEK_OFFSET = 0 220 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 221 222 UNNEST_COLUMN_ONLY = False 223 """Whether `UNNEST` table aliases are treated as column aliases.""" 224 225 ALIAS_POST_TABLESAMPLE = False 226 """Whether the table alias comes after tablesample.""" 227 228 TABLESAMPLE_SIZE_IS_PERCENT = False 229 """Whether a size in the table sample clause represents percentage.""" 230 231 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 232 """Specifies the strategy according to which identifiers should be normalized.""" 233 234 IDENTIFIERS_CAN_START_WITH_DIGIT = False 235 """Whether an unquoted identifier can start with a digit.""" 236 237 DPIPE_IS_STRING_CONCAT = True 238 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 239 240 STRICT_STRING_CONCAT = False 241 """Whether `CONCAT`'s arguments must be strings.""" 242 243 SUPPORTS_USER_DEFINED_TYPES = True 244 """Whether user-defined data types are supported.""" 245 246 SUPPORTS_SEMI_ANTI_JOIN = True 247 """Whether `SEMI` or `ANTI` joins are supported.""" 248 249 SUPPORTS_COLUMN_JOIN_MARKS = False 250 """Whether the old-style outer join (+) syntax is supported.""" 251 252 COPY_PARAMS_ARE_CSV = True 253 """Separator of COPY statement parameters.""" 254 255 NORMALIZE_FUNCTIONS: bool | str = "upper" 256 """ 257 Determines how function names are going to be normalized. 258 Possible values: 259 "upper" or True: Convert names to uppercase. 260 "lower": Convert names to lowercase. 261 False: Disables function name normalization. 262 """ 263 264 LOG_BASE_FIRST: t.Optional[bool] = True 265 """ 266 Whether the base comes first in the `LOG` function. 
267 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 268 """ 269 270 NULL_ORDERING = "nulls_are_small" 271 """ 272 Default `NULL` ordering method to use if not explicitly set. 273 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 274 """ 275 276 TYPED_DIVISION = False 277 """ 278 Whether the behavior of `a / b` depends on the types of `a` and `b`. 279 False means `a / b` is always float division. 280 True means `a / b` is integer division if both `a` and `b` are integers. 281 """ 282 283 SAFE_DIVISION = False 284 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 285 286 CONCAT_COALESCE = False 287 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 288 289 HEX_LOWERCASE = False 290 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 291 292 DATE_FORMAT = "'%Y-%m-%d'" 293 DATEINT_FORMAT = "'%Y%m%d'" 294 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 295 296 TIME_MAPPING: t.Dict[str, str] = {} 297 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 298 299 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 300 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 301 FORMAT_MAPPING: t.Dict[str, str] = {} 302 """ 303 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 304 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 305 """ 306 307 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 308 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 309 310 PSEUDOCOLUMNS: t.Set[str] = set() 311 """ 312 Columns that are auto-generated by the engine corresponding to this dialect. 
313 For example, such columns may be excluded from `SELECT *` queries. 314 """ 315 316 PREFER_CTE_ALIAS_COLUMN = False 317 """ 318 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 319 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 320 any projection aliases in the subquery. 321 322 For example, 323 WITH y(c) AS ( 324 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 325 ) SELECT c FROM y; 326 327 will be rewritten as 328 329 WITH y(c) AS ( 330 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 331 ) SELECT c FROM y; 332 """ 333 334 COPY_PARAMS_ARE_CSV = True 335 """ 336 Whether COPY statement parameters are separated by comma or whitespace 337 """ 338 339 FORCE_EARLY_ALIAS_REF_EXPANSION = False 340 """ 341 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 342 343 For example: 344 WITH data AS ( 345 SELECT 346 1 AS id, 347 2 AS my_id 348 ) 349 SELECT 350 id AS my_id 351 FROM 352 data 353 WHERE 354 my_id = 1 355 GROUP BY 356 my_id, 357 HAVING 358 my_id = 1 359 360 In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: 361 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 362 - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1" 363 """ 364 365 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 366 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 367 368 SUPPORTS_ORDER_BY_ALL = False 369 """ 370 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 371 """ 372 373 # --- Autofilled --- 374 375 tokenizer_class = Tokenizer 376 jsonpath_tokenizer_class = JSONPathTokenizer 377 parser_class = Parser 378 generator_class = 
Generator 379 380 # A trie of the time_mapping keys 381 TIME_TRIE: t.Dict = {} 382 FORMAT_TRIE: t.Dict = {} 383 384 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 385 INVERSE_TIME_TRIE: t.Dict = {} 386 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 387 INVERSE_FORMAT_TRIE: t.Dict = {} 388 389 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 390 391 # Delimiters for string literals and identifiers 392 QUOTE_START = "'" 393 QUOTE_END = "'" 394 IDENTIFIER_START = '"' 395 IDENTIFIER_END = '"' 396 397 # Delimiters for bit, hex, byte and unicode literals 398 BIT_START: t.Optional[str] = None 399 BIT_END: t.Optional[str] = None 400 HEX_START: t.Optional[str] = None 401 HEX_END: t.Optional[str] = None 402 BYTE_START: t.Optional[str] = None 403 BYTE_END: t.Optional[str] = None 404 UNICODE_START: t.Optional[str] = None 405 UNICODE_END: t.Optional[str] = None 406 407 DATE_PART_MAPPING = { 408 "Y": "YEAR", 409 "YY": "YEAR", 410 "YYY": "YEAR", 411 "YYYY": "YEAR", 412 "YR": "YEAR", 413 "YEARS": "YEAR", 414 "YRS": "YEAR", 415 "MM": "MONTH", 416 "MON": "MONTH", 417 "MONS": "MONTH", 418 "MONTHS": "MONTH", 419 "D": "DAY", 420 "DD": "DAY", 421 "DAYS": "DAY", 422 "DAYOFMONTH": "DAY", 423 "DAY OF WEEK": "DAYOFWEEK", 424 "WEEKDAY": "DAYOFWEEK", 425 "DOW": "DAYOFWEEK", 426 "DW": "DAYOFWEEK", 427 "WEEKDAY_ISO": "DAYOFWEEKISO", 428 "DOW_ISO": "DAYOFWEEKISO", 429 "DW_ISO": "DAYOFWEEKISO", 430 "DAY OF YEAR": "DAYOFYEAR", 431 "DOY": "DAYOFYEAR", 432 "DY": "DAYOFYEAR", 433 "W": "WEEK", 434 "WK": "WEEK", 435 "WEEKOFYEAR": "WEEK", 436 "WOY": "WEEK", 437 "WY": "WEEK", 438 "WEEK_ISO": "WEEKISO", 439 "WEEKOFYEARISO": "WEEKISO", 440 "WEEKOFYEAR_ISO": "WEEKISO", 441 "Q": "QUARTER", 442 "QTR": "QUARTER", 443 "QTRS": "QUARTER", 444 "QUARTERS": "QUARTER", 445 "H": "HOUR", 446 "HH": "HOUR", 447 "HR": "HOUR", 448 "HOURS": "HOUR", 449 "HRS": "HOUR", 450 "M": "MINUTE", 451 "MI": "MINUTE", 452 "MIN": "MINUTE", 453 "MINUTES": "MINUTE", 454 "MINS": "MINUTE", 455 "S": "SECOND", 456 "SEC": "SECOND", 457 "SECONDS": "SECOND", 
458 "SECS": "SECOND", 459 "MS": "MILLISECOND", 460 "MSEC": "MILLISECOND", 461 "MSECS": "MILLISECOND", 462 "MSECOND": "MILLISECOND", 463 "MSECONDS": "MILLISECOND", 464 "MILLISEC": "MILLISECOND", 465 "MILLISECS": "MILLISECOND", 466 "MILLISECON": "MILLISECOND", 467 "MILLISECONDS": "MILLISECOND", 468 "US": "MICROSECOND", 469 "USEC": "MICROSECOND", 470 "USECS": "MICROSECOND", 471 "MICROSEC": "MICROSECOND", 472 "MICROSECS": "MICROSECOND", 473 "USECOND": "MICROSECOND", 474 "USECONDS": "MICROSECOND", 475 "MICROSECONDS": "MICROSECOND", 476 "NS": "NANOSECOND", 477 "NSEC": "NANOSECOND", 478 "NANOSEC": "NANOSECOND", 479 "NSECOND": "NANOSECOND", 480 "NSECONDS": "NANOSECOND", 481 "NANOSECS": "NANOSECOND", 482 "EPOCH_SECOND": "EPOCH", 483 "EPOCH_SECONDS": "EPOCH", 484 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 485 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 486 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 487 "TZH": "TIMEZONE_HOUR", 488 "TZM": "TIMEZONE_MINUTE", 489 "DEC": "DECADE", 490 "DECS": "DECADE", 491 "DECADES": "DECADE", 492 "MIL": "MILLENIUM", 493 "MILS": "MILLENIUM", 494 "MILLENIA": "MILLENIUM", 495 "C": "CENTURY", 496 "CENT": "CENTURY", 497 "CENTS": "CENTURY", 498 "CENTURIES": "CENTURY", 499 } 500 501 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 502 exp.DataType.Type.BIGINT: { 503 exp.ApproxDistinct, 504 exp.ArraySize, 505 exp.Count, 506 exp.Length, 507 }, 508 exp.DataType.Type.BOOLEAN: { 509 exp.Between, 510 exp.Boolean, 511 exp.In, 512 exp.RegexpLike, 513 }, 514 exp.DataType.Type.DATE: { 515 exp.CurrentDate, 516 exp.Date, 517 exp.DateFromParts, 518 exp.DateStrToDate, 519 exp.DiToDate, 520 exp.StrToDate, 521 exp.TimeStrToDate, 522 exp.TsOrDsToDate, 523 }, 524 exp.DataType.Type.DATETIME: { 525 exp.CurrentDatetime, 526 exp.Datetime, 527 exp.DatetimeAdd, 528 exp.DatetimeSub, 529 }, 530 exp.DataType.Type.DOUBLE: { 531 exp.ApproxQuantile, 532 exp.Avg, 533 exp.Div, 534 exp.Exp, 535 exp.Ln, 536 exp.Log, 537 exp.Pow, 538 exp.Quantile, 539 
exp.Round, 540 exp.SafeDivide, 541 exp.Sqrt, 542 exp.Stddev, 543 exp.StddevPop, 544 exp.StddevSamp, 545 exp.Variance, 546 exp.VariancePop, 547 }, 548 exp.DataType.Type.INT: { 549 exp.Ceil, 550 exp.DatetimeDiff, 551 exp.DateDiff, 552 exp.TimestampDiff, 553 exp.TimeDiff, 554 exp.DateToDi, 555 exp.Levenshtein, 556 exp.Sign, 557 exp.StrPosition, 558 exp.TsOrDiToDi, 559 }, 560 exp.DataType.Type.JSON: { 561 exp.ParseJSON, 562 }, 563 exp.DataType.Type.TIME: { 564 exp.Time, 565 }, 566 exp.DataType.Type.TIMESTAMP: { 567 exp.CurrentTime, 568 exp.CurrentTimestamp, 569 exp.StrToTime, 570 exp.TimeAdd, 571 exp.TimeStrToTime, 572 exp.TimeSub, 573 exp.TimestampAdd, 574 exp.TimestampSub, 575 exp.UnixToTime, 576 }, 577 exp.DataType.Type.TINYINT: { 578 exp.Day, 579 exp.Month, 580 exp.Week, 581 exp.Year, 582 exp.Quarter, 583 }, 584 exp.DataType.Type.VARCHAR: { 585 exp.ArrayConcat, 586 exp.Concat, 587 exp.ConcatWs, 588 exp.DateToDateStr, 589 exp.GroupConcat, 590 exp.Initcap, 591 exp.Lower, 592 exp.Substring, 593 exp.TimeToStr, 594 exp.TimeToTimeStr, 595 exp.Trim, 596 exp.TsOrDsToDateStr, 597 exp.UnixToStr, 598 exp.UnixToTimeStr, 599 exp.Upper, 600 }, 601 } 602 603 ANNOTATORS: AnnotatorsType = { 604 **{ 605 expr_type: lambda self, e: self._annotate_unary(e) 606 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 607 }, 608 **{ 609 expr_type: lambda self, e: self._annotate_binary(e) 610 for expr_type in subclasses(exp.__name__, exp.Binary) 611 }, 612 **{ 613 expr_type: _annotate_with_type_lambda(data_type) 614 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 615 for expr_type in expressions 616 }, 617 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 618 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 619 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 620 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 621 exp.ArrayConcat: lambda self, e: 
self._annotate_by_args(e, "this", "expressions"), 622 exp.Bracket: lambda self, e: self._annotate_bracket(e), 623 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 624 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 625 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 626 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 627 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 628 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 629 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 630 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 631 exp.Div: lambda self, e: self._annotate_div(e), 632 exp.Dot: lambda self, e: self._annotate_dot(e), 633 exp.Explode: lambda self, e: self._annotate_explode(e), 634 exp.Extract: lambda self, e: self._annotate_extract(e), 635 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 636 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 637 e, exp.DataType.build("ARRAY<DATE>") 638 ), 639 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 640 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 641 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 642 exp.Literal: lambda self, e: self._annotate_literal(e), 643 exp.Map: lambda self, e: self._annotate_map(e), 644 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 645 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 646 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 647 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 648 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 649 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 650 exp.Struct: lambda self, e: self._annotate_struct(e), 651 exp.Sum: lambda 
self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 652 exp.Timestamp: lambda self, e: self._annotate_with_type( 653 e, 654 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 655 ), 656 exp.ToMap: lambda self, e: self._annotate_to_map(e), 657 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 658 exp.Unnest: lambda self, e: self._annotate_unnest(e), 659 exp.VarMap: lambda self, e: self._annotate_map(e), 660 } 661 662 @classmethod 663 def get_or_raise(cls, dialect: DialectType) -> Dialect: 664 """ 665 Look up a dialect in the global dialect registry and return it if it exists. 666 667 Args: 668 dialect: The target dialect. If this is a string, it can be optionally followed by 669 additional key-value pairs that are separated by commas and are used to specify 670 dialect settings, such as whether the dialect's identifiers are case-sensitive. 671 672 Example: 673 >>> dialect = dialect_class = get_or_raise("duckdb") 674 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 675 676 Returns: 677 The corresponding Dialect instance. 
678 """ 679 680 if not dialect: 681 return cls() 682 if isinstance(dialect, _Dialect): 683 return dialect() 684 if isinstance(dialect, Dialect): 685 return dialect 686 if isinstance(dialect, str): 687 try: 688 dialect_name, *kv_strings = dialect.split(",") 689 kv_pairs = (kv.split("=") for kv in kv_strings) 690 kwargs = {} 691 for pair in kv_pairs: 692 key = pair[0].strip() 693 value: t.Union[bool | str | None] = None 694 695 if len(pair) == 1: 696 # Default initialize standalone settings to True 697 value = True 698 elif len(pair) == 2: 699 value = pair[1].strip() 700 701 # Coerce the value to boolean if it matches to the truthy/falsy values below 702 value_lower = value.lower() 703 if value_lower in ("true", "1"): 704 value = True 705 elif value_lower in ("false", "0"): 706 value = False 707 708 kwargs[key] = value 709 710 except ValueError: 711 raise ValueError( 712 f"Invalid dialect format: '{dialect}'. " 713 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 714 ) 715 716 result = cls.get(dialect_name.strip()) 717 if not result: 718 from difflib import get_close_matches 719 720 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 721 if similar: 722 similar = f" Did you mean {similar}?" 
723 724 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 725 726 return result(**kwargs) 727 728 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 729 730 @classmethod 731 def format_time( 732 cls, expression: t.Optional[str | exp.Expression] 733 ) -> t.Optional[exp.Expression]: 734 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 735 if isinstance(expression, str): 736 return exp.Literal.string( 737 # the time formats are quoted 738 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 739 ) 740 741 if expression and expression.is_string: 742 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 743 744 return expression 745 746 def __init__(self, **kwargs) -> None: 747 normalization_strategy = kwargs.pop("normalization_strategy", None) 748 749 if normalization_strategy is None: 750 self.normalization_strategy = self.NORMALIZATION_STRATEGY 751 else: 752 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 753 754 self.settings = kwargs 755 756 def __eq__(self, other: t.Any) -> bool: 757 # Does not currently take dialect state into account 758 return type(self) == other 759 760 def __hash__(self) -> int: 761 # Does not currently take dialect state into account 762 return hash(type(self)) 763 764 def normalize_identifier(self, expression: E) -> E: 765 """ 766 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 767 768 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 769 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 770 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 771 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 
772 773 There are also dialects like Spark, which are case-insensitive even when quotes are 774 present, and dialects like MySQL, whose resolution rules match those employed by the 775 underlying operating system, for example they may always be case-sensitive in Linux. 776 777 Finally, the normalization behavior of some engines can even be controlled through flags, 778 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 779 780 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 781 that it can analyze queries in the optimizer and successfully capture their semantics. 782 """ 783 if ( 784 isinstance(expression, exp.Identifier) 785 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 786 and ( 787 not expression.quoted 788 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 789 ) 790 ): 791 expression.set( 792 "this", 793 ( 794 expression.this.upper() 795 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 796 else expression.this.lower() 797 ), 798 ) 799 800 return expression 801 802 def case_sensitive(self, text: str) -> bool: 803 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 804 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 805 return False 806 807 unsafe = ( 808 str.islower 809 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 810 else str.isupper 811 ) 812 return any(unsafe(char) for char in text) 813 814 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 815 """Checks if text can be identified given an identify option. 816 817 Args: 818 text: The text to check. 819 identify: 820 `"always"` or `True`: Always returns `True`. 821 `"safe"`: Only returns `True` if the identifier is case-insensitive. 822 823 Returns: 824 Whether the given text can be identified. 
825 """ 826 if identify is True or identify == "always": 827 return True 828 829 if identify == "safe": 830 return not self.case_sensitive(text) 831 832 return False 833 834 def quote_identifier(self, expression: E, identify: bool = True) -> E: 835 """ 836 Adds quotes to a given identifier. 837 838 Args: 839 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 840 identify: If set to `False`, the quotes will only be added if the identifier is deemed 841 "unsafe", with respect to its characters and this dialect's normalization strategy. 842 """ 843 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 844 name = expression.this 845 expression.set( 846 "quoted", 847 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 848 ) 849 850 return expression 851 852 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 853 if isinstance(path, exp.Literal): 854 path_text = path.name 855 if path.is_number: 856 path_text = f"[{path_text}]" 857 try: 858 return parse_json_path(path_text, self) 859 except ParseError as e: 860 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 861 862 return path 863 864 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 865 return self.parser(**opts).parse(self.tokenize(sql), sql) 866 867 def parse_into( 868 self, expression_type: exp.IntoType, sql: str, **opts 869 ) -> t.List[t.Optional[exp.Expression]]: 870 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 871 872 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 873 return self.generator(**opts).generate(expression, copy=copy) 874 875 def transpile(self, sql: str, **opts) -> t.List[str]: 876 return [ 877 self.generate(expression, copy=False, **opts) if expression else "" 878 for expression in self.parse(sql) 879 ] 880 881 def tokenize(self, sql: str) -> t.List[Token]: 882 return self.tokenizer.tokenize(sql) 883 884 @property 885 def tokenizer(self) -> Tokenizer: 886 return self.tokenizer_class(dialect=self) 887 888 @property 889 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 890 return self.jsonpath_tokenizer_class(dialect=self) 891 892 def parser(self, **opts) -> Parser: 893 return self.parser_class(dialect=self, **opts) 894 895 def generator(self, **opts) -> Generator: 896 return self.generator_class(dialect=self, **opts)
    def __init__(self, **kwargs) -> None:
        """Instantiate the dialect, optionally overriding its normalization strategy.

        Args:
            **kwargs: Dialect settings. The special key ``normalization_strategy``
                (matched case-insensitively against the enum values) overrides the
                class-level ``NORMALIZATION_STRATEGY``; every remaining key/value
                pair is stored verbatim in ``self.settings``.
        """
        # Pop first so the strategy key never leaks into `self.settings`.
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            # Enum values are uppercase, so normalize the user-supplied string.
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the LOG function. Possible values: True, False, and None (the latter meaning that LOG with two arguments is not supported).
Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".
Whether the behavior of `a / b` depends on the types of `a` and `b`. False means `a / b` is always float division; True means `a / b` is integer division if both `a` and `b` are integers.
A NULL
arg in CONCAT
yields NULL
by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime
formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy')
.
If empty, the corresponding trie will be constructed off of TIME_MAPPING
.
Mapping of an escaped sequence (e.g. the two characters `\n`) to its unescaped version (e.g. a literal newline character).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT *
queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example, WITH y(c) AS ( SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 ) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS ( SELECT 1 AS id, 2 AS my_id ) SELECT id AS my_id FROM data WHERE my_id = 1 GROUP BY my_id, HAVING my_id = 1
In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
662 @classmethod 663 def get_or_raise(cls, dialect: DialectType) -> Dialect: 664 """ 665 Look up a dialect in the global dialect registry and return it if it exists. 666 667 Args: 668 dialect: The target dialect. If this is a string, it can be optionally followed by 669 additional key-value pairs that are separated by commas and are used to specify 670 dialect settings, such as whether the dialect's identifiers are case-sensitive. 671 672 Example: 673 >>> dialect = dialect_class = get_or_raise("duckdb") 674 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 675 676 Returns: 677 The corresponding Dialect instance. 678 """ 679 680 if not dialect: 681 return cls() 682 if isinstance(dialect, _Dialect): 683 return dialect() 684 if isinstance(dialect, Dialect): 685 return dialect 686 if isinstance(dialect, str): 687 try: 688 dialect_name, *kv_strings = dialect.split(",") 689 kv_pairs = (kv.split("=") for kv in kv_strings) 690 kwargs = {} 691 for pair in kv_pairs: 692 key = pair[0].strip() 693 value: t.Union[bool | str | None] = None 694 695 if len(pair) == 1: 696 # Default initialize standalone settings to True 697 value = True 698 elif len(pair) == 2: 699 value = pair[1].strip() 700 701 # Coerce the value to boolean if it matches to the truthy/falsy values below 702 value_lower = value.lower() 703 if value_lower in ("true", "1"): 704 value = True 705 elif value_lower in ("false", "0"): 706 value = False 707 708 kwargs[key] = value 709 710 except ValueError: 711 raise ValueError( 712 f"Invalid dialect format: '{dialect}'. " 713 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 714 ) 715 716 result = cls.get(dialect_name.strip()) 717 if not result: 718 from difflib import get_close_matches 719 720 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 721 if similar: 722 similar = f" Did you mean {similar}?" 
723 724 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 725 726 return result(**kwargs) 727 728 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb") >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
730 @classmethod 731 def format_time( 732 cls, expression: t.Optional[str | exp.Expression] 733 ) -> t.Optional[exp.Expression]: 734 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 735 if isinstance(expression, str): 736 return exp.Literal.string( 737 # the time formats are quoted 738 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 739 ) 740 741 if expression and expression.is_string: 742 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 743 744 return expression
Converts a time format in this dialect to its equivalent Python strftime
format.
764 def normalize_identifier(self, expression: E) -> E: 765 """ 766 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 767 768 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 769 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 770 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 771 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 772 773 There are also dialects like Spark, which are case-insensitive even when quotes are 774 present, and dialects like MySQL, whose resolution rules match those employed by the 775 underlying operating system, for example they may always be case-sensitive in Linux. 776 777 Finally, the normalization behavior of some engines can even be controlled through flags, 778 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 779 780 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 781 that it can analyze queries in the optimizer and successfully capture their semantics. 782 """ 783 if ( 784 isinstance(expression, exp.Identifier) 785 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 786 and ( 787 not expression.quoted 788 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 789 ) 790 ): 791 expression.set( 792 "this", 793 ( 794 expression.this.upper() 795 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 796 else expression.this.lower() 797 ), 798 ) 799 800 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO
would be resolved as foo
in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
it would resolve it as FOO
. If it was quoted, it'd need to be treated as case-sensitive,
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
802 def case_sensitive(self, text: str) -> bool: 803 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 804 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 805 return False 806 807 unsafe = ( 808 str.islower 809 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 810 else str.isupper 811 ) 812 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
814 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 815 """Checks if text can be identified given an identify option. 816 817 Args: 818 text: The text to check. 819 identify: 820 `"always"` or `True`: Always returns `True`. 821 `"safe"`: Only returns `True` if the identifier is case-insensitive. 822 823 Returns: 824 Whether the given text can be identified. 825 """ 826 if identify is True or identify == "always": 827 return True 828 829 if identify == "safe": 830 return not self.case_sensitive(text) 831 832 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
"always"
orTrue
: Always returnsTrue
."safe"
: Only returnsTrue
if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
834 def quote_identifier(self, expression: E, identify: bool = True) -> E: 835 """ 836 Adds quotes to a given identifier. 837 838 Args: 839 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 840 identify: If set to `False`, the quotes will only be added if the identifier is deemed 841 "unsafe", with respect to its characters and this dialect's normalization strategy. 842 """ 843 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 844 name = expression.this 845 expression.set( 846 "quoted", 847 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 848 ) 849 850 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an Identifier, this method is a no-op.
- identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
852 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 853 if isinstance(path, exp.Literal): 854 path_text = path.name 855 if path.is_number: 856 path_text = f"[{path_text}]" 857 try: 858 return parse_json_path(path_text, self) 859 except ParseError as e: 860 logger.warning(f"Invalid JSON path syntax. {str(e)}") 861 862 return path
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Builds a generator callback that renders `exp.If` as a call to *name*.

    Args:
        name: The SQL function name to emit (defaults to `IF`).
        false_value: Fallback third argument used when the expression has no `false` arg.

    Returns:
        A callback suitable for a generator TRANSFORMS mapping.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        args = expression.args
        return self.func(name, expression.this, args.get("true"), args.get("false") or false_value)

    return _if_sql
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    """Renders JSON extraction via the arrow operators (`->` for JSON, `->>` for scalar)."""
    this = expression.this

    # Some dialects require the operand to be explicitly typed as JSON.
    needs_cast = (
        self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string
    )
    if needs_cast:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    operator = "->" if isinstance(expression, exp.JSONExtract) else "->>"
    return self.binary(expression, operator)
def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    """Renders `exp.StrPosition` as STRPOS, emulating the optional start position.

    Args:
        expression: The StrPosition node to render.
        generate_instance: Whether to forward the optional `instance` argument.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None

    if not position:
        return self.func("STRPOS", this, substr, instance)

    # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
    shifted = self.func("SUBSTR", this, position)
    return self.func("STRPOS", shifted, substr, instance) + f" + {position} - 1"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Renders a map construction as `MAP(k1, v1, k2, v2, ...)`.

    Falls back to passing the raw key/value operands when they aren't array literals.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave keys and values into a single flat argument list.
    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List) -> E:
        fmt = seq_get(args, 1)
        if not fmt:
            # Fall back to the dialect's time format (default=True) or the given default.
            fmt = Dialect[dialect].TIME_FORMAT if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=Dialect[dialect].format_time(fmt))

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Builds a callback that suppresses a time format equal to the dialect's default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        if fmt == Dialect.get_or_raise(dialect).TIME_FORMAT:
            return None
        return fmt

    return _time_format
def build_date_delta(
    exp_class: t.Type[E],
    unit_mapping: t.Optional[t.Dict[str, str]] = None,
    default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
    """Builds a parser for date-delta functions, normalizing the optional unit argument.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional translation table for unit names (keys lowercased).
        default_unit: unit assumed when the call has no explicit unit argument.
    """

    def _builder(args: t.List) -> E:
        # Three arguments means the call is `FUNC(unit, delta, this)`.
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)

        unit = None
        if has_unit or default_unit:
            unit = args[0] if has_unit else exp.Literal.string(default_unit)
            if unit_mapping:
                unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Builds a parser for date arithmetic whose second argument must be an INTERVAL.

    Raises:
        ParseError: If the second argument is not an `exp.Interval`.
    """

    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # Interval quantities given as strings are reinterpreted as numbers.
            amount = exp.Literal.number(amount.this)

        return expression_class(this=args[0], expression=amount, unit=unit_to_str(interval))

    return _builder
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parses DATE_TRUNC(unit, value), picking DateTrunc only for date-cast operands."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    truncates_date = isinstance(this, exp.Cast) and this.is_type("date")
    trunc_class = exp.DateTrunc if truncates_date else exp.TimestampTrunc
    return trunc_class(this=this, unit=unit)
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Builds a renderer producing `<DATA_TYPE>_<KIND>(this, INTERVAL ...)` syntax."""

    def func(self: Generator, expression: exp.Expression) -> str:
        # Wrap the delta and its unit in an explicit INTERVAL expression.
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        this = self.sql(expression, "this")
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    """Builds a renderer for TimestampTrunc as `DATE_TRUNC(unit, this[, zone])`."""

    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        call_args = [unit_to_str(expression), expression.this]
        if zone:
            # Only dialects that support it receive the optional time zone argument.
            call_args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *call_args)

    return _timestamptrunc_sql
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Transpile TIMESTAMP(...) for dialects without it, via CAST / AT TIME ZONE."""
    zone = expression.args.get("zone")

    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        # Keep the annotated type when the optimizer can infer one; default to TIMESTAMP.
        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))

    if zone.name.lower() in TIMEZONES:
        # Known timezone name: render as CAST(... AS TIMESTAMP) AT TIME ZONE <zone>.
        at_tz = exp.AtTimeZone(
            this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
            zone=zone,
        )
        return self.sql(at_tz)

    # Unrecognized zone: keep the two-argument TIMESTAMP call as-is.
    return self.func("TIMESTAMP", expression.this, zone)
def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    at_tz = exp.AtTimeZone(
        this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ),
        zone=expression.args.get("zone"),
    )
    return self.sql(exp.cast(at_tz, exp.DataType.Type.TIME))
def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    """Transpile BQ's DATETIME(...) for dialects that lack it."""
    this, expr = expression.this, expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        tz_cast = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        shifted = exp.AtTimeZone(this=tz_cast, zone=expr)
        return self.sql(exp.cast(shifted, exp.DataType.Type.TIMESTAMP))

    # DATETIME(date, time): add the two parts and cast the sum to TIMESTAMP.
    summed = exp.Add(
        this=exp.cast(this, exp.DataType.Type.DATE),
        expression=exp.cast(expr, exp.DataType.Type.TIME),
    )
    return self.sql(exp.cast(summed, exp.DataType.Type.TIMESTAMP))
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an ENCODE/DECODE-style call, warning when the charset isn't utf-8."""
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    # The trailing "replace" argument is only emitted for dialects that accept it.
    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0))."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # DISTINCT can't be expressed through SUM/IF; warn and drop it.
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        cond = cond.expressions[0]

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with position/chars/collation, or fall back to the plain form."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    pieces = [
        f"{trim_type} " if trim_type else "",
        f"{remove_chars} " if remove_chars else "",
        "FROM " if trim_type or remove_chars else "",
        target,
        f" COLLATE {collation}" if collation else "",
    ]
    return f"TRIM({''.join(pieces)})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about optional args this dialect can't express."""
    bad_args = [key for key in ("position", "occurrence", "parameters") if expression.args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about optional args this dialect can't express."""
    bad_args = [key for key in ("position", "occurrence", "modifiers") if expression.args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the column-name suffixes produced by a list of PIVOT aggregations."""

    def _unquote_identifiers(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Aggregations without aliases are used as suffixes (e.g. col_avg(foo)). We
        # unquote identifiers because they're going to be quoted again in the base
        # parser's `_parse_pivot`, due to `to_identifier`; otherwise we'd end up with
        # `col_avg(`foo`)` (notice the double quotes).
        unquoted = agg.transform(_unquote_identifiers)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Produce an ARG_MAX/ARG_MIN generator that rejects the three-argument (count) form."""

    def _sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")
        return self.func(name, expression.this, expression.expression)

    return _sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap the TS_OR_DS_ADD operand in a cast to the expression's return type (in place)."""
    operand = expression.this.copy()
    target = expression.return_type

    if target.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(operand, target))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Produce a generator rendering a date add/diff as `name(unit, expression, this)`."""

    def _sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            # Preserve the declared return type by casting the operand first.
            expression = ts_or_ds_add_cast(expression)

        return self.func(name, unit_to_var(expression), expression.expression, expression.this)

    return _sql
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the expression's "unit" arg as a string literal.

    Placeholders pass through untouched; a missing unit falls back to `default`
    (or None when `default` is empty).
    """
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    if default:
        return exp.Literal.string(default)
    return None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Emulate LAST_DAY: truncate to month start, add a month, step back one day."""
    month_start = exp.func("date_trunc", "month", expression.this)
    next_month_start = exp.func("date_add", month_start, 1, "month")
    last_day = exp.func("date_sub", next_month_start, 1, "day")

    return self.sql(exp.cast(last_day, exp.DataType.Type.DATE))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    # Names that refer to the merge target (the table itself plus its alias, if any).
    targets = {normalize(expression.this.this)}
    if alias:
        targets.add(normalize(alias.this))

    def _strip_target_ref(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets:
            return exp.column(node.this)
        return node

    for when in expression.expressions:
        when.transform(_strip_target_ref, copy=False)

    return self.merge_sql(expression)
# Remove table refs from columns in when statements.
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    """Return a builder converting path-segment literals into a JSONPath expression."""

    def _build(args: t.List) -> F:
        parts: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                subscript = index if zero_based_indexing else index - 1
                parts.append(exp.JSONPathSubscript(this=subscript))
            else:
                parts.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=parts),
            only_json_types=arrow_req_json_type,
        )

    return _build
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    """Produce a generator rendering JSON extraction as a function call or `op`-joined segments."""

    def _sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            # Non-JSONPath argument: fall back to a plain renamed function call.
            return rename_func(name)(self, expression)

        rendered = []
        for segment in path.expressions:
            segment_sql = self.sql(segment)
            if not segment_sql:
                continue

            if isinstance(segment, exp.JSONPathPart) and (
                quoted_index or not isinstance(segment, exp.JSONPathSubscript)
            ):
                segment_sql = f"{self.dialect.QUOTE_START}{segment_sql}{self.dialect.QUOTE_END}"

            rendered.append(segment_sql)

        if op:
            return f" {op} ".join([self.sql(expression.this), *rendered])
        return self.func(name, expression.this, *rendered)

    return _sql
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    """Transpile array FILTER into ARRAY(SELECT alias FROM UNNEST(arr) AS alias WHERE cond)."""
    cond = expression.expression

    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        # Single-parameter lambda: reuse its parameter as the unnest alias.
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    subquery = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[subquery]))
def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    """Return a hook that fills unparameterized DECIMAL types with default precision/scale."""

    def _build(dtype: exp.DataType) -> exp.DataType:
        # Leave already-parameterized types alone; without a default precision
        # there's nothing to fill in.
        if dtype.expressions or precision is None:
            return dtype

        if scale is None:
            return exp.DataType.build(f"DECIMAL({precision})")
        return exp.DataType.build(f"DECIMAL({precision}, {scale})")

    return _build
def build_timestamp_from_parts(args: t.List) -> exp.Func:
    """Parse TIMESTAMP_FROM_PARTS; the two-argument form stays Anonymous."""
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)
def sequence_sql(self: Generator, expression: exp.GenerateSeries) -> str:
    """Render GENERATE_SERIES as SEQUENCE(start, end[, step]).

    When either bound is a cast to a timestamp type, the other bound is cast to the
    same type so SEQUENCE receives consistently typed arguments.
    """
    start = expression.args["start"]
    end = expression.args["end"]
    step = expression.args.get("step")

    if isinstance(start, exp.Cast):
        target_type = start.to
    elif isinstance(end, exp.Cast):
        target_type = end.to
    else:
        target_type = None

    if target_type and target_type.is_type("timestamp"):
        # Guard on isinstance: when only `end` was the Cast, the original
        # `target_type is start.to` dereferenced `.to` on a non-Cast `start`,
        # raising AttributeError.
        if isinstance(start, exp.Cast) and target_type is start.to:
            end = exp.cast(end, target_type)
        else:
            start = exp.cast(start, target_type)

    return self.func("SEQUENCE", start, end, step)