sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
  18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  21
  22
  23if t.TYPE_CHECKING:
  24    from sqlglot._typing import B, E, F
  25
  26logger = logging.getLogger("sqlglot")
  27
  28
  29class Dialects(str, Enum):
  30    """Dialects supported by SQLGLot."""
  31
  32    DIALECT = ""
  33
  34    ATHENA = "athena"
  35    BIGQUERY = "bigquery"
  36    CLICKHOUSE = "clickhouse"
  37    DATABRICKS = "databricks"
  38    DORIS = "doris"
  39    DRILL = "drill"
  40    DUCKDB = "duckdb"
  41    HIVE = "hive"
  42    MYSQL = "mysql"
  43    ORACLE = "oracle"
  44    POSTGRES = "postgres"
  45    PRESTO = "presto"
  46    PRQL = "prql"
  47    REDSHIFT = "redshift"
  48    SNOWFLAKE = "snowflake"
  49    SPARK = "spark"
  50    SPARK2 = "spark2"
  51    SQLITE = "sqlite"
  52    STARROCKS = "starrocks"
  53    TABLEAU = "tableau"
  54    TERADATA = "teradata"
  55    TRINO = "trino"
  56    TSQL = "tsql"
  57
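A quick usage sketch (editorial, not part of the module): because `Dialects` subclasses `str`, its members can be passed anywhere a dialect name string is expected, e.g. in the top-level `sqlglot.transpile` call.

    import sqlglot
    from sqlglot.dialects.dialect import Dialects

    # Equivalent ways of selecting the source/target dialects.
    sqlglot.transpile("SELECT 1", read="duckdb", write="snowflake")
    sqlglot.transpile("SELECT 1", read=Dialects.DUCKDB, write=Dialects.SNOWFLAKE)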
  58
  59class NormalizationStrategy(str, AutoName):
  60    """Specifies the strategy according to which identifiers should be normalized."""
  61
  62    LOWERCASE = auto()
  63    """Unquoted identifiers are lowercased."""
  64
  65    UPPERCASE = auto()
  66    """Unquoted identifiers are uppercased."""
  67
  68    CASE_SENSITIVE = auto()
  69    """Always case-sensitive, regardless of quotes."""
  70
  71    CASE_INSENSITIVE = auto()
  72    """Always case-insensitive, regardless of quotes."""
  73
  74
  75class _Dialect(type):
  76    classes: t.Dict[str, t.Type[Dialect]] = {}
  77
  78    def __eq__(cls, other: t.Any) -> bool:
  79        if cls is other:
  80            return True
  81        if isinstance(other, str):
  82            return cls is cls.get(other)
  83        if isinstance(other, Dialect):
  84            return cls is type(other)
  85
  86        return False
  87
  88    def __hash__(cls) -> int:
  89        return hash(cls.__name__.lower())
  90
  91    @classmethod
  92    def __getitem__(cls, key: str) -> t.Type[Dialect]:
  93        return cls.classes[key]
  94
  95    @classmethod
  96    def get(
  97        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
  98    ) -> t.Optional[t.Type[Dialect]]:
  99        return cls.classes.get(key, default)
 100
 101    def __new__(cls, clsname, bases, attrs):
 102        klass = super().__new__(cls, clsname, bases, attrs)
 103        enum = Dialects.__members__.get(clsname.upper())
 104        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 105
 106        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 107        klass.FORMAT_TRIE = (
 108            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 109        )
 110        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 111        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 112
 113        base = seq_get(bases, 0)
 114        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
 115        base_parser = (getattr(base, "parser_class", Parser),)
 116        base_generator = (getattr(base, "generator_class", Generator),)
 117
 118        klass.tokenizer_class = klass.__dict__.get(
 119            "Tokenizer", type("Tokenizer", base_tokenizer, {})
 120        )
 121        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
 122        klass.generator_class = klass.__dict__.get(
 123            "Generator", type("Generator", base_generator, {})
 124        )
 125
 126        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 127        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 128            klass.tokenizer_class._IDENTIFIERS.items()
 129        )[0]
 130
 131        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 132            return next(
 133                (
 134                    (s, e)
  135                    for s, (e, fmt_token) in klass.tokenizer_class._FORMAT_STRINGS.items()
  136                    if fmt_token == token_type
 137                ),
 138                (None, None),
 139            )
 140
 141        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 142        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 143        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 144        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 145
 146        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
 147            klass.UNESCAPED_SEQUENCES = {
 148                "\\a": "\a",
 149                "\\b": "\b",
 150                "\\f": "\f",
 151                "\\n": "\n",
 152                "\\r": "\r",
 153                "\\t": "\t",
 154                "\\v": "\v",
 155                "\\\\": "\\",
 156                **klass.UNESCAPED_SEQUENCES,
 157            }
 158
 159        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
 160
 161        if enum not in ("", "bigquery"):
 162            klass.generator_class.SELECT_KINDS = ()
 163
 164        if enum not in ("", "athena", "presto", "trino"):
 165            klass.generator_class.TRY_SUPPORTED = False
 166
 167        if enum not in ("", "databricks", "hive", "spark", "spark2"):
 168            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
 169            for modifier in ("cluster", "distribute", "sort"):
 170                modifier_transforms.pop(modifier, None)
 171
 172            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
 173
 174        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 175            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 176                TokenType.ANTI,
 177                TokenType.SEMI,
 178            }
 179
 180        return klass
 181
 182
 183class Dialect(metaclass=_Dialect):
 184    INDEX_OFFSET = 0
 185    """The base index offset for arrays."""
 186
 187    WEEK_OFFSET = 0
 188    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 189
 190    UNNEST_COLUMN_ONLY = False
 191    """Whether `UNNEST` table aliases are treated as column aliases."""
 192
 193    ALIAS_POST_TABLESAMPLE = False
 194    """Whether the table alias comes after tablesample."""
 195
 196    TABLESAMPLE_SIZE_IS_PERCENT = False
 197    """Whether a size in the table sample clause represents percentage."""
 198
 199    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 200    """Specifies the strategy according to which identifiers should be normalized."""
 201
 202    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 203    """Whether an unquoted identifier can start with a digit."""
 204
 205    DPIPE_IS_STRING_CONCAT = True
 206    """Whether the DPIPE token (`||`) is a string concatenation operator."""
 207
 208    STRICT_STRING_CONCAT = False
 209    """Whether `CONCAT`'s arguments must be strings."""
 210
 211    SUPPORTS_USER_DEFINED_TYPES = True
 212    """Whether user-defined data types are supported."""
 213
 214    SUPPORTS_SEMI_ANTI_JOIN = True
 215    """Whether `SEMI` or `ANTI` joins are supported."""
 216
 217    NORMALIZE_FUNCTIONS: bool | str = "upper"
 218    """
 219    Determines how function names are going to be normalized.
 220    Possible values:
 221        "upper" or True: Convert names to uppercase.
 222        "lower": Convert names to lowercase.
 223        False: Disables function name normalization.
 224    """
 225
 226    LOG_BASE_FIRST: t.Optional[bool] = True
 227    """
 228    Whether the base comes first in the `LOG` function.
 229    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
 230    """
 231
 232    NULL_ORDERING = "nulls_are_small"
 233    """
 234    Default `NULL` ordering method to use if not explicitly set.
 235    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 236    """
 237
 238    TYPED_DIVISION = False
 239    """
 240    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 241    False means `a / b` is always float division.
 242    True means `a / b` is integer division if both `a` and `b` are integers.
 243    """
 244
 245    SAFE_DIVISION = False
 246    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 247
 248    CONCAT_COALESCE = False
 249    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 250
 251    DATE_FORMAT = "'%Y-%m-%d'"
 252    DATEINT_FORMAT = "'%Y%m%d'"
 253    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 254
 255    TIME_MAPPING: t.Dict[str, str] = {}
 256    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
 257
 258    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 259    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 260    FORMAT_MAPPING: t.Dict[str, str] = {}
 261    """
 262    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 263    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 264    """
 265
 266    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
 267    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 268
 269    PSEUDOCOLUMNS: t.Set[str] = set()
 270    """
 271    Columns that are auto-generated by the engine corresponding to this dialect.
 272    For example, such columns may be excluded from `SELECT *` queries.
 273    """
 274
 275    PREFER_CTE_ALIAS_COLUMN = False
 276    """
 277    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 278    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 279    any projection aliases in the subquery.
 280
 281    For example,
 282        WITH y(c) AS (
 283            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 284        ) SELECT c FROM y;
 285
 286        will be rewritten as
 287
 288        WITH y(c) AS (
 289            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 290        ) SELECT c FROM y;
 291    """
 292
 293    # --- Autofilled ---
 294
 295    tokenizer_class = Tokenizer
 296    parser_class = Parser
 297    generator_class = Generator
 298
 299    # A trie of the time_mapping keys
 300    TIME_TRIE: t.Dict = {}
 301    FORMAT_TRIE: t.Dict = {}
 302
 303    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 304    INVERSE_TIME_TRIE: t.Dict = {}
 305
 306    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 307
 308    # Delimiters for string literals and identifiers
 309    QUOTE_START = "'"
 310    QUOTE_END = "'"
 311    IDENTIFIER_START = '"'
 312    IDENTIFIER_END = '"'
 313
 314    # Delimiters for bit, hex, byte and unicode literals
 315    BIT_START: t.Optional[str] = None
 316    BIT_END: t.Optional[str] = None
 317    HEX_START: t.Optional[str] = None
 318    HEX_END: t.Optional[str] = None
 319    BYTE_START: t.Optional[str] = None
 320    BYTE_END: t.Optional[str] = None
 321    UNICODE_START: t.Optional[str] = None
 322    UNICODE_END: t.Optional[str] = None
 323
 324    # Separator of COPY statement parameters
 325    COPY_PARAMS_ARE_CSV = True
 326
 327    @classmethod
 328    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 329        """
 330        Look up a dialect in the global dialect registry and return it if it exists.
 331
 332        Args:
 333            dialect: The target dialect. If this is a string, it can be optionally followed by
 334                additional key-value pairs that are separated by commas and are used to specify
 335                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 336
 337        Example:
  338            >>> dialect = Dialect.get_or_raise("duckdb")
  339            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
 340
 341        Returns:
 342            The corresponding Dialect instance.
 343        """
 344
 345        if not dialect:
 346            return cls()
 347        if isinstance(dialect, _Dialect):
 348            return dialect()
 349        if isinstance(dialect, Dialect):
 350            return dialect
 351        if isinstance(dialect, str):
 352            try:
 353                dialect_name, *kv_pairs = dialect.split(",")
 354                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 355            except ValueError:
 356                raise ValueError(
 357                    f"Invalid dialect format: '{dialect}'. "
  358                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
 359                )
 360
 361            result = cls.get(dialect_name.strip())
 362            if not result:
 363                from difflib import get_close_matches
 364
 365                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 366                if similar:
 367                    similar = f" Did you mean {similar}?"
 368
 369                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 370
 371            return result(**kwargs)
 372
 373        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 374
 375    @classmethod
 376    def format_time(
 377        cls, expression: t.Optional[str | exp.Expression]
 378    ) -> t.Optional[exp.Expression]:
 379        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 380        if isinstance(expression, str):
 381            return exp.Literal.string(
 382                # the time formats are quoted
 383                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 384            )
 385
 386        if expression and expression.is_string:
 387            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 388
 389        return expression
 390
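For instance (a sketch; the conversion depends on the dialect's `TIME_MAPPING`):

    from sqlglot.dialects.hive import Hive

    # Hive's 'yyyy-MM-dd' maps to the python strftime format '%Y-%m-%d'.
    Hive.format_time("'yyyy-MM-dd'").sql()  # "'%Y-%m-%d'"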
 391    def __init__(self, **kwargs) -> None:
 392        normalization_strategy = kwargs.get("normalization_strategy")
 393
 394        if normalization_strategy is None:
 395            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 396        else:
 397            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 398
 399    def __eq__(self, other: t.Any) -> bool:
 400        # Does not currently take dialect state into account
 401        return type(self) == other
 402
 403    def __hash__(self) -> int:
 404        # Does not currently take dialect state into account
 405        return hash(type(self))
 406
 407    def normalize_identifier(self, expression: E) -> E:
 408        """
 409        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 410
 411        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 412        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 413        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 414        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 415
 416        There are also dialects like Spark, which are case-insensitive even when quotes are
 417        present, and dialects like MySQL, whose resolution rules match those employed by the
  418        underlying operating system; for example, they may always be case-sensitive on Linux.
 419
 420        Finally, the normalization behavior of some engines can even be controlled through flags,
 421        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 422
 423        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 424        that it can analyze queries in the optimizer and successfully capture their semantics.
 425        """
 426        if (
 427            isinstance(expression, exp.Identifier)
 428            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 429            and (
 430                not expression.quoted
 431                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 432            )
 433        ):
 434            expression.set(
 435                "this",
 436                (
 437                    expression.this.upper()
 438                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 439                    else expression.this.lower()
 440                ),
 441            )
 442
 443        return expression
 444
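To make the rules above concrete, a small sketch:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO")  # unquoted

    Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name   # 'foo'
    Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name  # 'FOO'

    # Quoted identifiers are preserved under case-sensitive resolution rules.
    quoted = exp.to_identifier("FoO", quoted=True)
    Dialect.get_or_raise("postgres").normalize_identifier(quoted).name         # 'FoO'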
 445    def case_sensitive(self, text: str) -> bool:
  446        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
 447        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 448            return False
 449
 450        unsafe = (
 451            str.islower
 452            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 453            else str.isupper
 454        )
 455        return any(unsafe(char) for char in text)
 456
 457    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 458        """Checks if text can be identified given an identify option.
 459
 460        Args:
 461            text: The text to check.
 462            identify:
 463                `"always"` or `True`: Always returns `True`.
 464                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 465
 466        Returns:
 467            Whether the given text can be identified.
 468        """
 469        if identify is True or identify == "always":
 470            return True
 471
 472        if identify == "safe":
 473            return not self.case_sensitive(text)
 474
 475        return False
 476
 477    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 478        """
 479        Adds quotes to a given identifier.
 480
 481        Args:
 482            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 483            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 484                "unsafe", with respect to its characters and this dialect's normalization strategy.
 485        """
 486        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
 487            name = expression.this
 488            expression.set(
 489                "quoted",
 490                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 491            )
 492
 493        return expression
 494
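For example (a sketch; Snowflake uppercases unquoted identifiers, so lowercase text is considered unsafe):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    snow = Dialect.get_or_raise("snowflake")
    snow.quote_identifier(exp.to_identifier("foo"), identify=False).sql("snowflake")  # '"foo"'
    snow.quote_identifier(exp.to_identifier("FOO"), identify=False).sql("snowflake")  # 'FOO'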
 495    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 496        if isinstance(path, exp.Literal):
 497            path_text = path.name
 498            if path.is_number:
 499                path_text = f"[{path_text}]"
 500
 501            try:
 502                return parse_json_path(path_text)
 503            except ParseError as e:
 504                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 505
 506        return path
 507
 508    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 509        return self.parser(**opts).parse(self.tokenize(sql), sql)
 510
 511    def parse_into(
 512        self, expression_type: exp.IntoType, sql: str, **opts
 513    ) -> t.List[t.Optional[exp.Expression]]:
 514        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 515
 516    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 517        return self.generator(**opts).generate(expression, copy=copy)
 518
 519    def transpile(self, sql: str, **opts) -> t.List[str]:
 520        return [
 521            self.generate(expression, copy=False, **opts) if expression else ""
 522            for expression in self.parse(sql)
 523        ]
 524
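A minimal round trip through these helpers (a sketch):

    from sqlglot.dialects.dialect import Dialect

    mysql = Dialect.get_or_raise("mysql")
    duckdb = Dialect.get_or_raise("duckdb")

    ast = mysql.parse("SELECT `a` FROM t")[0]  # backticks are MySQL-style identifiers
    duckdb.generate(ast)                       # 'SELECT "a" FROM t'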
 525    def tokenize(self, sql: str) -> t.List[Token]:
 526        return self.tokenizer.tokenize(sql)
 527
 528    @property
 529    def tokenizer(self) -> Tokenizer:
 530        if not hasattr(self, "_tokenizer"):
 531            self._tokenizer = self.tokenizer_class(dialect=self)
 532        return self._tokenizer
 533
 534    def parser(self, **opts) -> Parser:
 535        return self.parser_class(dialect=self, **opts)
 536
 537    def generator(self, **opts) -> Generator:
 538        return self.generator_class(dialect=self, **opts)
 539
 540
 541DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 542
 543
 544def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 545    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 546
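Helpers like `rename_func` (and the factories below) are designed to be plugged into a dialect's `Generator.TRANSFORMS` table. A hedged sketch of that pattern:

    from sqlglot import exp
    from sqlglot.dialects.dialect import rename_func
    from sqlglot.dialects.duckdb import DuckDB

    class MyDialect(DuckDB):  # hypothetical dialect, for illustration only
        class Generator(DuckDB.Generator):
            TRANSFORMS = {
                **DuckDB.Generator.TRANSFORMS,
                exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            }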
 547
 548def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 549    if expression.args.get("accuracy"):
 550        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 551    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 552
 553
 554def if_sql(
 555    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 556) -> t.Callable[[Generator, exp.If], str]:
 557    def _if_sql(self: Generator, expression: exp.If) -> str:
 558        return self.func(
 559            name,
 560            expression.this,
 561            expression.args.get("true"),
 562            expression.args.get("false") or false_value,
 563        )
 564
 565    return _if_sql
 566
 567
 568def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 569    this = expression.this
 570    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 571        this.replace(exp.cast(this, exp.DataType.Type.JSON))
 572
 573    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 574
 575
 576def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 577    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
 578
 579
 580def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
 581    elem = seq_get(expression.expressions, 0)
 582    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
 583        return self.func("ARRAY", elem)
 584    return inline_array_sql(self, expression)
 585
 586
 587def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 588    return self.like_sql(
 589        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 590    )
 591
 592
 593def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 594    zone = self.sql(expression, "this")
 595    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 596
 597
 598def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 599    if expression.args.get("recursive"):
 600        self.unsupported("Recursive CTEs are unsupported")
 601        expression.args["recursive"] = False
 602    return self.with_sql(expression)
 603
 604
 605def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 606    n = self.sql(expression, "this")
 607    d = self.sql(expression, "expression")
 608    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 609
 610
 611def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 612    self.unsupported("TABLESAMPLE unsupported")
 613    return self.sql(expression.this)
 614
 615
 616def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 617    self.unsupported("PIVOT unsupported")
 618    return ""
 619
 620
 621def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 622    return self.cast_sql(expression)
 623
 624
 625def no_comment_column_constraint_sql(
 626    self: Generator, expression: exp.CommentColumnConstraint
 627) -> str:
 628    self.unsupported("CommentColumnConstraint unsupported")
 629    return ""
 630
 631
 632def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 633    self.unsupported("MAP_FROM_ENTRIES unsupported")
 634    return ""
 635
 636
 637def str_position_sql(
 638    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
 639) -> str:
 640    this = self.sql(expression, "this")
 641    substr = self.sql(expression, "substr")
 642    position = self.sql(expression, "position")
 643    instance = expression.args.get("instance") if generate_instance else None
 644    position_offset = ""
 645
 646    if position:
 647        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
 648        this = self.func("SUBSTR", this, position)
 649        position_offset = f" + {position} - 1"
 650
 651    return self.func("STRPOS", this, substr, instance) + position_offset
 652
 653
 654def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 655    return (
 656        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 657    )
 658
 659
 660def var_map_sql(
 661    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 662) -> str:
 663    keys = expression.args["keys"]
 664    values = expression.args["values"]
 665
 666    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 667        self.unsupported("Cannot convert array columns into map.")
 668        return self.func(map_func_name, keys, values)
 669
 670    args = []
 671    for key, value in zip(keys.expressions, values.expressions):
 672        args.append(self.sql(key))
 673        args.append(self.sql(value))
 674
 675    return self.func(map_func_name, *args)
 676
 677
 678def build_formatted_time(
 679    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 680) -> t.Callable[[t.List], E]:
 681    """Helper used for time expressions.
 682
 683    Args:
 684        exp_class: the expression class to instantiate.
 685        dialect: target sql dialect.
  686        default: the default format; `True` means use the dialect's `TIME_FORMAT`.
 687
 688    Returns:
 689        A callable that can be used to return the appropriately formatted time expression.
 690    """
 691
 692    def _builder(args: t.List):
 693        return exp_class(
 694            this=seq_get(args, 0),
 695            format=Dialect[dialect].format_time(
 696                seq_get(args, 1)
 697                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 698            ),
 699        )
 700
 701    return _builder
 702
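As a usage sketch, a builder along these lines is registered on the parser side, e.g. for a Hive-style `TO_DATE`:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    # Parses TO_DATE(x[, fmt]) into exp.TsOrDsToDate, translating fmt from
    # Hive's time-format tokens into python strftime tokens.
    to_date_builder = build_formatted_time(exp.TsOrDsToDate, "hive")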
 703
 704def time_format(
 705    dialect: DialectType = None,
 706) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 707    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 708        """
 709        Returns the time format for a given expression, unless it's equivalent
 710        to the default time format of the dialect of interest.
 711        """
 712        time_format = self.format_time(expression)
 713        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 714
 715    return _time_format
 716
 717
 718def build_date_delta(
 719    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 720) -> t.Callable[[t.List], E]:
 721    def _builder(args: t.List) -> E:
 722        unit_based = len(args) == 3
 723        this = args[2] if unit_based else seq_get(args, 0)
 724        unit = args[0] if unit_based else exp.Literal.string("DAY")
 725        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 726        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 727
 728    return _builder
 729
 730
 731def build_date_delta_with_interval(
 732    expression_class: t.Type[E],
 733) -> t.Callable[[t.List], t.Optional[E]]:
 734    def _builder(args: t.List) -> t.Optional[E]:
 735        if len(args) < 2:
 736            return None
 737
 738        interval = args[1]
 739
 740        if not isinstance(interval, exp.Interval):
 741            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 742
 743        expression = interval.this
 744        if expression and expression.is_string:
 745            expression = exp.Literal.number(expression.this)
 746
 747        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 748
 749    return _builder
 750
 751
 752def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 753    unit = seq_get(args, 0)
 754    this = seq_get(args, 1)
 755
 756    if isinstance(this, exp.Cast) and this.is_type("date"):
 757        return exp.DateTrunc(unit=unit, this=this)
 758    return exp.TimestampTrunc(this=this, unit=unit)
 759
 760
 761def date_add_interval_sql(
 762    data_type: str, kind: str
 763) -> t.Callable[[Generator, exp.Expression], str]:
 764    def func(self: Generator, expression: exp.Expression) -> str:
 765        this = self.sql(expression, "this")
 766        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
 767        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 768
 769    return func
 770
 771
 772def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 773    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
 774
 775
 776def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 777    if not expression.expression:
 778        from sqlglot.optimizer.annotate_types import annotate_types
 779
 780        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 781        return self.sql(exp.cast(expression.this, target_type))
 782    if expression.text("expression").lower() in TIMEZONES:
 783        return self.sql(
 784            exp.AtTimeZone(
 785                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
 786                zone=expression.expression,
 787            )
 788        )
 789    return self.func("TIMESTAMP", expression.this, expression.expression)
 790
 791
 792def locate_to_strposition(args: t.List) -> exp.Expression:
 793    return exp.StrPosition(
 794        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 795    )
 796
 797
 798def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 799    return self.func(
 800        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 801    )
 802
 803
 804def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 805    return self.sql(
 806        exp.Substring(
 807            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 808        )
 809    )
 810
 811
  812def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
 813    return self.sql(
 814        exp.Substring(
 815            this=expression.this,
 816            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 817        )
 818    )
 819
 820
 821def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 822    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
 823
 824
 825def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 826    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
 827
 828
  829# Used for Presto and DuckDB, which use functions that don't support a charset and assume utf-8
 830def encode_decode_sql(
 831    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 832) -> str:
 833    charset = expression.args.get("charset")
 834    if charset and charset.name.lower() != "utf-8":
 835        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 836
 837    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 838
 839
 840def min_or_least(self: Generator, expression: exp.Min) -> str:
 841    name = "LEAST" if expression.expressions else "MIN"
 842    return rename_func(name)(self, expression)
 843
 844
 845def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 846    name = "GREATEST" if expression.expressions else "MAX"
 847    return rename_func(name)(self, expression)
 848
 849
 850def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 851    cond = expression.this
 852
 853    if isinstance(expression.this, exp.Distinct):
 854        cond = expression.this.expressions[0]
 855        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 856
 857    return self.func("sum", exp.func("if", cond, 1, 0))
 858
 859
 860def trim_sql(self: Generator, expression: exp.Trim) -> str:
 861    target = self.sql(expression, "this")
 862    trim_type = self.sql(expression, "position")
 863    remove_chars = self.sql(expression, "expression")
 864    collation = self.sql(expression, "collation")
 865
 866    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 867    if not remove_chars and not collation:
 868        return self.trim_sql(expression)
 869
 870    trim_type = f"{trim_type} " if trim_type else ""
 871    remove_chars = f"{remove_chars} " if remove_chars else ""
 872    from_part = "FROM " if trim_type or remove_chars else ""
 873    collation = f" COLLATE {collation}" if collation else ""
 874    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 875
 876
 877def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 878    return self.func("STRPTIME", expression.this, self.format_time(expression))
 879
 880
 881def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 882    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 883
 884
 885def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 886    delim, *rest_args = expression.expressions
 887    return self.sql(
 888        reduce(
 889            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 890            rest_args,
 891        )
 892    )
 893
 894
 895def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 896    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 897    if bad_args:
 898        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 899
 900    return self.func(
 901        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 902    )
 903
 904
 905def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 906    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
 907    if bad_args:
 908        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 909
 910    return self.func(
 911        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 912    )
 913
 914
 915def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 916    names = []
 917    for agg in aggregations:
 918        if isinstance(agg, exp.Alias):
 919            names.append(agg.alias)
 920        else:
 921            """
 922            This case corresponds to aggregations without aliases being used as suffixes
 923            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 924            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
  925            Otherwise, we'd end up with `col_avg(`foo`)` (notice the backticks around foo).
 926            """
 927            agg_all_unquoted = agg.transform(
 928                lambda node: (
 929                    exp.Identifier(this=node.name, quoted=False)
 930                    if isinstance(node, exp.Identifier)
 931                    else node
 932                )
 933            )
 934            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 935
 936    return names
 937
 938
 939def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 940    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 941
 942
 943# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 944def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 945    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 946
 947
 948def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 949    return self.func("MAX", expression.this)
 950
 951
 952def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 953    a = self.sql(expression.left)
 954    b = self.sql(expression.right)
 955    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 956
 957
 958def is_parse_json(expression: exp.Expression) -> bool:
 959    return isinstance(expression, exp.ParseJSON) or (
 960        isinstance(expression, exp.Cast) and expression.is_type("json")
 961    )
 962
 963
 964def isnull_to_is_null(args: t.List) -> exp.Expression:
 965    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 966
 967
 968def generatedasidentitycolumnconstraint_sql(
 969    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 970) -> str:
 971    start = self.sql(expression, "start") or "1"
 972    increment = self.sql(expression, "increment") or "1"
 973    return f"IDENTITY({start}, {increment})"
 974
 975
 976def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 977    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 978        if expression.args.get("count"):
 979            self.unsupported(f"Only two arguments are supported in function {name}.")
 980
 981        return self.func(name, expression.this, expression.expression)
 982
 983    return _arg_max_or_min_sql
 984
 985
 986def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 987    this = expression.this.copy()
 988
 989    return_type = expression.return_type
 990    if return_type.is_type(exp.DataType.Type.DATE):
 991        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
 992        # can truncate timestamp strings, because some dialects can't cast them to DATE
 993        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
 994
 995    expression.this.replace(exp.cast(this, return_type))
 996    return expression
 997
 998
 999def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1000    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1001        if cast and isinstance(expression, exp.TsOrDsAdd):
1002            expression = ts_or_ds_add_cast(expression)
1003
1004        return self.func(
1005            name,
1006            unit_to_var(expression),
1007            expression.expression,
1008            expression.this,
1009        )
1010
1011    return _delta_sql
1012
1013
1014def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1015    unit = expression.args.get("unit")
1016
1017    if isinstance(unit, exp.Placeholder):
1018        return unit
1019    if unit:
1020        return exp.Literal.string(unit.name)
1021    return exp.Literal.string(default) if default else None
1022
1023
1024def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1025    unit = expression.args.get("unit")
1026
1027    if isinstance(unit, (exp.Var, exp.Placeholder)):
1028        return unit
1029    return exp.Var(this=default) if default else None
1030
1031
1032def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1033    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1034    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1035    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1036
1037    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1038
1039
1040def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1041    """Remove table refs from columns in when statements."""
1042    alias = expression.this.args.get("alias")
1043
1044    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1045        return self.dialect.normalize_identifier(identifier).name if identifier else None
1046
1047    targets = {normalize(expression.this.this)}
1048
1049    if alias:
1050        targets.add(normalize(alias.this))
1051
1052    for when in expression.expressions:
1053        when.transform(
1054            lambda node: (
1055                exp.column(node.this)
1056                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1057                else node
1058            ),
1059            copy=False,
1060        )
1061
1062    return self.merge_sql(expression)
1063
1064
1065def build_json_extract_path(
1066    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1067) -> t.Callable[[t.List], F]:
1068    def _builder(args: t.List) -> F:
1069        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1070        for arg in args[1:]:
1071            if not isinstance(arg, exp.Literal):
1072                # We use the fallback parser because we can't really transpile non-literals safely
1073                return expr_type.from_arg_list(args)
1074
1075            text = arg.name
1076            if is_int(text):
1077                index = int(text)
1078                segments.append(
1079                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1080                )
1081            else:
1082                segments.append(exp.JSONPathKey(this=text))
1083
1084        # This is done to avoid failing in the expression validator due to the arg count
1085        del args[2:]
1086        return expr_type(
1087            this=seq_get(args, 0),
1088            expression=exp.JSONPath(expressions=segments),
1089            only_json_types=arrow_req_json_type,
1090        )
1091
1092    return _builder
1093
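A sketch of what the builder produces:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_json_extract_path

    builder = build_json_extract_path(exp.JSONExtract)
    node = builder([exp.Literal.string('{"a": 1}'), exp.Literal.string("a")])
    # node is an exp.JSONExtract whose expression is an exp.JSONPath with
    # segments [JSONPathRoot, JSONPathKey(this="a")], i.e. $.a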
1094
1095def json_extract_segments(
1096    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1097) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1098    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1099        path = expression.expression
1100        if not isinstance(path, exp.JSONPath):
1101            return rename_func(name)(self, expression)
1102
1103        segments = []
1104        for segment in path.expressions:
 1105            segment_sql = self.sql(segment)
 1106            if segment_sql:
 1107                if isinstance(segment, exp.JSONPathPart) and (
 1108                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
 1109                ):
 1110                    segment_sql = f"{self.dialect.QUOTE_START}{segment_sql}{self.dialect.QUOTE_END}"
 1111
 1112                segments.append(segment_sql)
1113
1114        if op:
1115            return f" {op} ".join([self.sql(expression.this), *segments])
1116        return self.func(name, expression.this, *segments)
1117
1118    return _json_extract_segments
1119
1120
1121def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1122    if isinstance(expression.this, exp.JSONPathWildcard):
1123        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1124
1125    return expression.name
1126
1127
1128def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1129    cond = expression.expression
1130    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1131        alias = cond.expressions[0]
1132        cond = cond.this
1133    elif isinstance(cond, exp.Predicate):
1134        alias = "_u"
1135    else:
1136        self.unsupported("Unsupported filter condition")
1137        return ""
1138
1139    unnest = exp.Unnest(expressions=[expression.this])
1140    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1141    return self.sql(exp.Array(expressions=[filtered]))
1142
1143
 1144def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
1145    return self.func(
1146        "TO_NUMBER",
1147        expression.this,
1148        expression.args.get("format"),
1149        expression.args.get("nlsparam"),
1150    )
logger = <Logger sqlglot (WARNING)>
class Dialects(builtins.str, enum.Enum):
30class Dialects(str, Enum):
31    """Dialects supported by SQLGLot."""
32
33    DIALECT = ""
34
35    ATHENA = "athena"
36    BIGQUERY = "bigquery"
37    CLICKHOUSE = "clickhouse"
38    DATABRICKS = "databricks"
39    DORIS = "doris"
40    DRILL = "drill"
41    DUCKDB = "duckdb"
42    HIVE = "hive"
43    MYSQL = "mysql"
44    ORACLE = "oracle"
45    POSTGRES = "postgres"
46    PRESTO = "presto"
47    PRQL = "prql"
48    REDSHIFT = "redshift"
49    SNOWFLAKE = "snowflake"
50    SPARK = "spark"
51    SPARK2 = "spark2"
52    SQLITE = "sqlite"
53    STARROCKS = "starrocks"
54    TABLEAU = "tableau"
55    TERADATA = "teradata"
56    TRINO = "trino"
57    TSQL = "tsql"

Dialects supported by SQLGLot.

DIALECT = <Dialects.DIALECT: ''>
ATHENA = <Dialects.ATHENA: 'athena'>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
PRQL = <Dialects.PRQL: 'prql'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
60class NormalizationStrategy(str, AutoName):
61    """Specifies the strategy according to which identifiers should be normalized."""
62
63    LOWERCASE = auto()
64    """Unquoted identifiers are lowercased."""
65
66    UPPERCASE = auto()
67    """Unquoted identifiers are uppercased."""
68
69    CASE_SENSITIVE = auto()
70    """Always case-sensitive, regardless of quotes."""
71
72    CASE_INSENSITIVE = auto()
73    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.

Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
184class Dialect(metaclass=_Dialect):
185    INDEX_OFFSET = 0
186    """The base index offset for arrays."""
187
188    WEEK_OFFSET = 0
189    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
190
191    UNNEST_COLUMN_ONLY = False
192    """Whether `UNNEST` table aliases are treated as column aliases."""
193
194    ALIAS_POST_TABLESAMPLE = False
195    """Whether the table alias comes after tablesample."""
196
197    TABLESAMPLE_SIZE_IS_PERCENT = False
198    """Whether a size in the table sample clause represents percentage."""
199
200    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
201    """Specifies the strategy according to which identifiers should be normalized."""
202
203    IDENTIFIERS_CAN_START_WITH_DIGIT = False
204    """Whether an unquoted identifier can start with a digit."""
205
206    DPIPE_IS_STRING_CONCAT = True
207    """Whether the DPIPE token (`||`) is a string concatenation operator."""
208
209    STRICT_STRING_CONCAT = False
210    """Whether `CONCAT`'s arguments must be strings."""
211
212    SUPPORTS_USER_DEFINED_TYPES = True
213    """Whether user-defined data types are supported."""
214
215    SUPPORTS_SEMI_ANTI_JOIN = True
216    """Whether `SEMI` or `ANTI` joins are supported."""
217
218    NORMALIZE_FUNCTIONS: bool | str = "upper"
219    """
220    Determines how function names are going to be normalized.
221    Possible values:
222        "upper" or True: Convert names to uppercase.
223        "lower": Convert names to lowercase.
224        False: Disables function name normalization.
225    """
226
227    LOG_BASE_FIRST: t.Optional[bool] = True
228    """
229    Whether the base comes first in the `LOG` function.
230    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
231    """
232
233    NULL_ORDERING = "nulls_are_small"
234    """
235    Default `NULL` ordering method to use if not explicitly set.
236    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
237    """
238
239    TYPED_DIVISION = False
240    """
241    Whether the behavior of `a / b` depends on the types of `a` and `b`.
242    False means `a / b` is always float division.
243    True means `a / b` is integer division if both `a` and `b` are integers.
244    """
245
246    SAFE_DIVISION = False
247    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
248
249    CONCAT_COALESCE = False
250    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
251
252    DATE_FORMAT = "'%Y-%m-%d'"
253    DATEINT_FORMAT = "'%Y%m%d'"
254    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
255
256    TIME_MAPPING: t.Dict[str, str] = {}
257    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
258
259    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
260    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
261    FORMAT_MAPPING: t.Dict[str, str] = {}
262    """
263    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
264    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
265    """
266
267    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
268    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
269
270    PSEUDOCOLUMNS: t.Set[str] = set()
271    """
272    Columns that are auto-generated by the engine corresponding to this dialect.
273    For example, such columns may be excluded from `SELECT *` queries.
274    """
275
276    PREFER_CTE_ALIAS_COLUMN = False
277    """
278    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
279    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
280    any projection aliases in the subquery.
281
282    For example,
283        WITH y(c) AS (
284            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
285        ) SELECT c FROM y;
286
287        will be rewritten as
288
289        WITH y(c) AS (
290            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
291        ) SELECT c FROM y;
292    """
293
294    # --- Autofilled ---
295
296    tokenizer_class = Tokenizer
297    parser_class = Parser
298    generator_class = Generator
299
300    # A trie of the time_mapping keys
301    TIME_TRIE: t.Dict = {}
302    FORMAT_TRIE: t.Dict = {}
303
304    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
305    INVERSE_TIME_TRIE: t.Dict = {}
306
307    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
308
309    # Delimiters for string literals and identifiers
310    QUOTE_START = "'"
311    QUOTE_END = "'"
312    IDENTIFIER_START = '"'
313    IDENTIFIER_END = '"'
314
315    # Delimiters for bit, hex, byte and unicode literals
316    BIT_START: t.Optional[str] = None
317    BIT_END: t.Optional[str] = None
318    HEX_START: t.Optional[str] = None
319    HEX_END: t.Optional[str] = None
320    BYTE_START: t.Optional[str] = None
321    BYTE_END: t.Optional[str] = None
322    UNICODE_START: t.Optional[str] = None
323    UNICODE_END: t.Optional[str] = None
324
325    # Separator of COPY statement parameters
326    COPY_PARAMS_ARE_CSV = True
327
328    @classmethod
329    def get_or_raise(cls, dialect: DialectType) -> Dialect:
330        """
331        Look up a dialect in the global dialect registry and return it if it exists.
332
333        Args:
334            dialect: The target dialect. If this is a string, it can be optionally followed by
335                additional key-value pairs that are separated by commas and are used to specify
336                dialect settings, such as whether the dialect's identifiers are case-sensitive.
337
338        Example:
339            >>> dialect = dialect_class = get_or_raise("duckdb")
340            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
341
342        Returns:
343            The corresponding Dialect instance.
344        """
345
346        if not dialect:
347            return cls()
348        if isinstance(dialect, _Dialect):
349            return dialect()
350        if isinstance(dialect, Dialect):
351            return dialect
352        if isinstance(dialect, str):
353            try:
354                dialect_name, *kv_pairs = dialect.split(",")
355                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
356            except ValueError:
357                raise ValueError(
358                    f"Invalid dialect format: '{dialect}'. "
359                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
360                )
361
362            result = cls.get(dialect_name.strip())
363            if not result:
364                from difflib import get_close_matches
365
366                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
367                if similar:
368                    similar = f" Did you mean {similar}?"
369
370                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
371
372            return result(**kwargs)
373
374        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
375
376    @classmethod
377    def format_time(
378        cls, expression: t.Optional[str | exp.Expression]
379    ) -> t.Optional[exp.Expression]:
380        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
381        if isinstance(expression, str):
382            return exp.Literal.string(
383                # the time formats are quoted
384                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
385            )
386
387        if expression and expression.is_string:
388            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
389
390        return expression
391
392    def __init__(self, **kwargs) -> None:
393        normalization_strategy = kwargs.get("normalization_strategy")
394
395        if normalization_strategy is None:
396            self.normalization_strategy = self.NORMALIZATION_STRATEGY
397        else:
398            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
399
400    def __eq__(self, other: t.Any) -> bool:
401        # Does not currently take dialect state into account
402        return type(self) == other
403
404    def __hash__(self) -> int:
405        # Does not currently take dialect state into account
406        return hash(type(self))
407
408    def normalize_identifier(self, expression: E) -> E:
409        """
410        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
411
412        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
413        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
414        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
415        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
416
417        There are also dialects like Spark, which are case-insensitive even when quotes are
418        present, and dialects like MySQL, whose resolution rules match those employed by the
419        underlying operating system; for example, they may always be case-sensitive on Linux.
420
421        Finally, the normalization behavior of some engines can even be controlled through flags,
422        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
423
424        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
425        that it can analyze queries in the optimizer and successfully capture their semantics.
426        """
427        if (
428            isinstance(expression, exp.Identifier)
429            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
430            and (
431                not expression.quoted
432                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
433            )
434        ):
435            expression.set(
436                "this",
437                (
438                    expression.this.upper()
439                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
440                    else expression.this.lower()
441                ),
442            )
443
444        return expression
445
446    def case_sensitive(self, text: str) -> bool:
447        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
448        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
449            return False
450
451        unsafe = (
452            str.islower
453            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
454            else str.isupper
455        )
456        return any(unsafe(char) for char in text)
457
458    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
459        """Checks if text can be identified given an identify option.
460
461        Args:
462            text: The text to check.
463            identify:
464                `"always"` or `True`: Always returns `True`.
465                `"safe"`: Only returns `True` if the identifier is case-insensitive.
466
467        Returns:
468            Whether the given text can be identified.
469        """
470        if identify is True or identify == "always":
471            return True
472
473        if identify == "safe":
474            return not self.case_sensitive(text)
475
476        return False
477
478    def quote_identifier(self, expression: E, identify: bool = True) -> E:
479        """
480        Adds quotes to a given identifier.
481
482        Args:
483            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
484            identify: If set to `False`, the quotes will only be added if the identifier is deemed
485                "unsafe", with respect to its characters and this dialect's normalization strategy.
486        """
487        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
488            name = expression.this
489            expression.set(
490                "quoted",
491                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
492            )
493
494        return expression
495
496    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
497        if isinstance(path, exp.Literal):
498            path_text = path.name
499            if path.is_number:
500                path_text = f"[{path_text}]"
501
502            try:
503                return parse_json_path(path_text)
504            except ParseError as e:
505                logger.warning(f"Invalid JSON path syntax. {str(e)}")
506
507        return path
508
509    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
510        return self.parser(**opts).parse(self.tokenize(sql), sql)
511
512    def parse_into(
513        self, expression_type: exp.IntoType, sql: str, **opts
514    ) -> t.List[t.Optional[exp.Expression]]:
515        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
516
517    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
518        return self.generator(**opts).generate(expression, copy=copy)
519
520    def transpile(self, sql: str, **opts) -> t.List[str]:
521        return [
522            self.generate(expression, copy=False, **opts) if expression else ""
523            for expression in self.parse(sql)
524        ]
525
526    def tokenize(self, sql: str) -> t.List[Token]:
527        return self.tokenizer.tokenize(sql)
528
529    @property
530    def tokenizer(self) -> Tokenizer:
531        if not hasattr(self, "_tokenizer"):
532            self._tokenizer = self.tokenizer_class(dialect=self)
533        return self._tokenizer
534
535    def parser(self, **opts) -> Parser:
536        return self.parser_class(dialect=self, **opts)
537
538    def generator(self, **opts) -> Generator:
539        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
392    def __init__(self, **kwargs) -> None:
393        normalization_strategy = kwargs.get("normalization_strategy")
394
395        if normalization_strategy is None:
396            self.normalization_strategy = self.NORMALIZATION_STRATEGY
397        else:
398            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.
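
For instance, the strategy can be overridden when a dialect is instantiated; a minimal sketch, using the string form that get_or_raise (documented below) parses into keyword arguments:

>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive").normalization_strategy
<NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>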

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:
  • "upper" or True: Convert names to uppercase.
  • "lower": Convert names to lowercase.
  • False: Disables function name normalization.

LOG_BASE_FIRST: Optional[bool] = True

Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG)

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie is constructed from TIME_MAPPING.

UNESCAPED_SEQUENCES: Dict[str, str] = {}

Mapping of an escaped sequence (\n) to its unescaped version (a literal newline).

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
ESCAPED_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
COPY_PARAMS_ARE_CSV = True
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
328    @classmethod
329    def get_or_raise(cls, dialect: DialectType) -> Dialect:
330        """
331        Look up a dialect in the global dialect registry and return it if it exists.
332
333        Args:
334            dialect: The target dialect. If this is a string, it can be optionally followed by
335                additional key-value pairs that are separated by commas and are used to specify
336                dialect settings, such as whether the dialect's identifiers are case-sensitive.
337
338        Example:
339            >>> dialect = Dialect.get_or_raise("duckdb")
340            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
341
342        Returns:
343            The corresponding Dialect instance.
344        """
345
346        if not dialect:
347            return cls()
348        if isinstance(dialect, _Dialect):
349            return dialect()
350        if isinstance(dialect, Dialect):
351            return dialect
352        if isinstance(dialect, str):
353            try:
354                dialect_name, *kv_pairs = dialect.split(",")
355                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
356            except ValueError:
357                raise ValueError(
358                    f"Invalid dialect format: '{dialect}'. "
359                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
360                )
361
362            result = cls.get(dialect_name.strip())
363            if not result:
364                from difflib import get_close_matches
365
366                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
367                if similar:
368                    similar = f" Did you mean {similar}?"
369
370                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
371
372            return result(**kwargs)
373
374        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
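
If the name is unknown, a ValueError is raised, with a close-match suggestion when one exists; a small sketch of the behavior implemented above:

>>> Dialect.get_or_raise("postgress")
Traceback (most recent call last):
  ...
ValueError: Unknown dialect 'postgress'. Did you mean postgres?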

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
376    @classmethod
377    def format_time(
378        cls, expression: t.Optional[str | exp.Expression]
379    ) -> t.Optional[exp.Expression]:
380        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
381        if isinstance(expression, str):
382            return exp.Literal.string(
383                # the time formats are quoted
384                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
385            )
386
387        if expression and expression.is_string:
388            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
389
390        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
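
A minimal sketch, assuming Hive's TIME_MAPPING translates yyyy, MM and dd to %Y, %m and %d; note that plain-string inputs are expected to include their surrounding quotes:

>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect["hive"].format_time("'yyyy-MM-dd'").sql()
"'%Y-%m-%d'"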

def normalize_identifier(self, expression: ~E) -> ~E:
408    def normalize_identifier(self, expression: E) -> E:
409        """
410        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
411
412        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
413        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
414        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
415        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
416
417        There are also dialects like Spark, which are case-insensitive even when quotes are
418        present, and dialects like MySQL, whose resolution rules match those employed by the
419        underlying operating system; for example, they may always be case-sensitive on Linux.
420
421        Finally, the normalization behavior of some engines can even be controlled through flags,
422        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
423
424        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
425        that it can analyze queries in the optimizer and successfully capture their semantics.
426        """
427        if (
428            isinstance(expression, exp.Identifier)
429            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
430            and (
431                not expression.quoted
432                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
433            )
434        ):
435            expression.set(
436                "this",
437                (
438                    expression.this.upper()
439                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
440                    else expression.this.lower()
441                ),
442            )
443
444        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
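
For example, a sketch relying on Postgres defaulting to the LOWERCASE strategy and Snowflake to UPPERCASE:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
'foo'
>>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
'FOO'
>>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO", quoted=True)).name
'FoO'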

def case_sensitive(self, text: str) -> bool:
446    def case_sensitive(self, text: str) -> bool:
447        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
448        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
449            return False
450
451        unsafe = (
452            str.islower
453            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
454            else str.isupper
455        )
456        return any(unsafe(char) for char in text)

Checks if text contains any case-sensitive characters, based on the dialect's rules.
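
Under a LOWERCASE strategy the uppercase characters are the "unsafe" ones, and vice versa; a quick sketch, assuming Postgres lowercases unquoted identifiers:

>>> d = Dialect.get_or_raise("postgres")
>>> d.case_sensitive("foo"), d.case_sensitive("FoO")
(False, True)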

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
458    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
459        """Checks if text can be identified given an identify option.
460
461        Args:
462            text: The text to check.
463            identify:
464                `"always"` or `True`: Always returns `True`.
465                `"safe"`: Only returns `True` if the identifier is case-insensitive.
466
467        Returns:
468            Whether the given text can be identified.
469        """
470        if identify is True or identify == "always":
471            return True
472
473        if identify == "safe":
474            return not self.case_sensitive(text)
475
476        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify:
      "always" or True: Always returns True.
      "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.
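
A short sketch, again assuming Postgres's LOWERCASE strategy:

>>> d = Dialect.get_or_raise("postgres")
>>> d.can_identify("Foo", "safe")
False
>>> d.can_identify("Foo", "always")
True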

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
478    def quote_identifier(self, expression: E, identify: bool = True) -> E:
479        """
480        Adds quotes to a given identifier.
481
482        Args:
483            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
484            identify: If set to `False`, the quotes will only be added if the identifier is deemed
485                "unsafe", with respect to its characters and this dialect's normalization strategy.
486        """
487        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
488            name = expression.this
489            expression.set(
490                "quoted",
491                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
492            )
493
494        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
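
A small sketch of both modes, assuming Postgres's LOWERCASE strategy, under which "FoO" is case-sensitive and is quoted even with identify=False:

>>> from sqlglot import exp
>>> d = Dialect.get_or_raise("postgres")
>>> d.quote_identifier(exp.to_identifier("foo"), identify=False).sql()
'foo'
>>> d.quote_identifier(exp.to_identifier("FoO"), identify=False).sql()
'"FoO"'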
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
496    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
497        if isinstance(path, exp.Literal):
498            path_text = path.name
499            if path.is_number:
500                path_text = f"[{path_text}]"
501
502            try:
503                return parse_json_path(path_text)
504            except ParseError as e:
505                logger.warning(f"Invalid JSON path syntax. {str(e)}")
506
507        return path
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
509    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
510        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
512    def parse_into(
513        self, expression_type: exp.IntoType, sql: str, **opts
514    ) -> t.List[t.Optional[exp.Expression]]:
515        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
517    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
518        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
520    def transpile(self, sql: str, **opts) -> t.List[str]:
521        return [
522            self.generate(expression, copy=False, **opts) if expression else ""
523            for expression in self.parse(sql)
524        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
526    def tokenize(self, sql: str) -> t.List[Token]:
527        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
529    @property
530    def tokenizer(self) -> Tokenizer:
531        if not hasattr(self, "_tokenizer"):
532            self._tokenizer = self.tokenizer_class(dialect=self)
533        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
535    def parser(self, **opts) -> Parser:
536        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
538    def generator(self, **opts) -> Generator:
539        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
545def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
546    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
549def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
550    if expression.args.get("accuracy"):
551        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
552    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
555def if_sql(
556    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
557) -> t.Callable[[Generator, exp.If], str]:
558    def _if_sql(self: Generator, expression: exp.If) -> str:
559        return self.func(
560            name,
561            expression.this,
562            expression.args.get("true"),
563            expression.args.get("false") or false_value,
564        )
565
566    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
569def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
570    this = expression.this
571    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
572        this.replace(exp.cast(this, exp.DataType.Type.JSON))
573
574    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
577def inline_array_sql(self: Generator, expression: exp.Array) -> str:
578    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
def inline_array_unless_query( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
581def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
582    elem = seq_get(expression.expressions, 0)
583    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
584        return self.func("ARRAY", elem)
585    return inline_array_sql(self, expression)
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
588def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
589    return self.like_sql(
590        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
591    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
594def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
595    zone = self.sql(expression, "this")
596    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
599def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
600    if expression.args.get("recursive"):
601        self.unsupported("Recursive CTEs are unsupported")
602        expression.args["recursive"] = False
603    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
606def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
607    n = self.sql(expression, "this")
608    d = self.sql(expression, "expression")
609    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
612def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
613    self.unsupported("TABLESAMPLE unsupported")
614    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
617def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
618    self.unsupported("PIVOT unsupported")
619    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
622def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
623    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
626def no_comment_column_constraint_sql(
627    self: Generator, expression: exp.CommentColumnConstraint
628) -> str:
629    self.unsupported("CommentColumnConstraint unsupported")
630    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
633def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
634    self.unsupported("MAP_FROM_ENTRIES unsupported")
635    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition, generate_instance: bool = False) -> str:
638def str_position_sql(
639    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
640) -> str:
641    this = self.sql(expression, "this")
642    substr = self.sql(expression, "substr")
643    position = self.sql(expression, "position")
644    instance = expression.args.get("instance") if generate_instance else None
645    position_offset = ""
646
647    if position:
648        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
649        this = self.func("SUBSTR", this, position)
650        position_offset = f" + {position} - 1"
651
652    return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
655def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
656    return (
657        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
658    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
661def var_map_sql(
662    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
663) -> str:
664    keys = expression.args["keys"]
665    values = expression.args["values"]
666
667    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
668        self.unsupported("Cannot convert array columns into map.")
669        return self.func(map_func_name, keys, values)
670
671    args = []
672    for key, value in zip(keys.expressions, values.expressions):
673        args.append(self.sql(key))
674        args.append(self.sql(value))
675
676    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
679def build_formatted_time(
680    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
681) -> t.Callable[[t.List], E]:
682    """Helper used for time expressions.
683
684    Args:
685        exp_class: the expression class to instantiate.
686        dialect: target sql dialect.
687        default: the default format, True being time.
688
689    Returns:
690        A callable that can be used to return the appropriately formatted time expression.
691    """
692
693    def _builder(args: t.List):
694        return exp_class(
695            this=seq_get(args, 0),
696            format=Dialect[dialect].format_time(
697                seq_get(args, 1)
698                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
699            ),
700        )
701
702    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
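
A sketch of the returned builder in action, assuming Hive's TIME_MAPPING translates yyyy-MM-dd to %Y-%m-%d; the format argument is rewritten into its Python strftime equivalent:

>>> from sqlglot import exp
>>> builder = build_formatted_time(exp.StrToTime, "hive")
>>> builder([exp.column("x"), exp.Literal.string("yyyy-MM-dd")]).args["format"].sql()
"'%Y-%m-%d'"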

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
705def time_format(
706    dialect: DialectType = None,
707) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
708    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
709        """
710        Returns the time format for a given expression, unless it's equivalent
711        to the default time format of the dialect of interest.
712        """
713        time_format = self.format_time(expression)
714        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
715
716    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
719def build_date_delta(
720    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
721) -> t.Callable[[t.List], E]:
722    def _builder(args: t.List) -> E:
723        unit_based = len(args) == 3
724        this = args[2] if unit_based else seq_get(args, 0)
725        unit = args[0] if unit_based else exp.Literal.string("DAY")
726        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
727        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
728
729    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
732def build_date_delta_with_interval(
733    expression_class: t.Type[E],
734) -> t.Callable[[t.List], t.Optional[E]]:
735    def _builder(args: t.List) -> t.Optional[E]:
736        if len(args) < 2:
737            return None
738
739        interval = args[1]
740
741        if not isinstance(interval, exp.Interval):
742            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
743
744        expression = interval.this
745        if expression and expression.is_string:
746            expression = exp.Literal.number(expression.this)
747
748        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
749
750    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
753def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
754    unit = seq_get(args, 0)
755    this = seq_get(args, 1)
756
757    if isinstance(this, exp.Cast) and this.is_type("date"):
758        return exp.DateTrunc(unit=unit, this=this)
759    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
762def date_add_interval_sql(
763    data_type: str, kind: str
764) -> t.Callable[[Generator, exp.Expression], str]:
765    def func(self: Generator, expression: exp.Expression) -> str:
766        this = self.sql(expression, "this")
767        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
768        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
769
770    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
773def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
774    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
777def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
778    if not expression.expression:
779        from sqlglot.optimizer.annotate_types import annotate_types
780
781        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
782        return self.sql(exp.cast(expression.this, target_type))
783    if expression.text("expression").lower() in TIMEZONES:
784        return self.sql(
785            exp.AtTimeZone(
786                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
787                zone=expression.expression,
788            )
789        )
790    return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
793def locate_to_strposition(args: t.List) -> exp.Expression:
794    return exp.StrPosition(
795        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
796    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
799def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
800    return self.func(
801        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
802    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
805def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
806    return self.sql(
807        exp.Substring(
808            this=expression.this, start=exp.Literal.number(1), length=expression.expression
809        )
810    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
813def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
814    return self.sql(
815        exp.Substring(
816            this=expression.this,
817            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
818        )
819    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
822def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
823    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
826def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
827    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
831def encode_decode_sql(
832    self: Generator, expression: exp.Expression, name: str, replace: bool = True
833) -> str:
834    charset = expression.args.get("charset")
835    if charset and charset.name.lower() != "utf-8":
836        self.unsupported(f"Expected utf-8 character set, got {charset}.")
837
838    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
841def min_or_least(self: Generator, expression: exp.Min) -> str:
842    name = "LEAST" if expression.expressions else "MIN"
843    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
846def max_or_greatest(self: Generator, expression: exp.Max) -> str:
847    name = "GREATEST" if expression.expressions else "MAX"
848    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
851def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
852    cond = expression.this
853
854    if isinstance(expression.this, exp.Distinct):
855        cond = expression.this.expressions[0]
856        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
857
858    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
861def trim_sql(self: Generator, expression: exp.Trim) -> str:
862    target = self.sql(expression, "this")
863    trim_type = self.sql(expression, "position")
864    remove_chars = self.sql(expression, "expression")
865    collation = self.sql(expression, "collation")
866
867    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
868    if not remove_chars and not collation:
869        return self.trim_sql(expression)
870
871    trim_type = f"{trim_type} " if trim_type else ""
872    remove_chars = f"{remove_chars} " if remove_chars else ""
873    from_part = "FROM " if trim_type or remove_chars else ""
874    collation = f" COLLATE {collation}" if collation else ""
875    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
878def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
879    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
882def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
883    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
886def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
887    delim, *rest_args = expression.expressions
888    return self.sql(
889        reduce(
890            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
891            rest_args,
892        )
893    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
896def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
897    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
898    if bad_args:
899        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
900
901    return self.func(
902        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
903    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
906def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
907    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
908    if bad_args:
909        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
910
911    return self.func(
912        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
913    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
916def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
917    names = []
918    for agg in aggregations:
919        if isinstance(agg, exp.Alias):
920            names.append(agg.alias)
921        else:
922            """
923            This case corresponds to aggregations without aliases being used as suffixes
924            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
925            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
926            Otherwise, we'd end up with `col_avg(`foo`)` (note the nested backticks).
927            """
928            agg_all_unquoted = agg.transform(
929                lambda node: (
930                    exp.Identifier(this=node.name, quoted=False)
931                    if isinstance(node, exp.Identifier)
932                    else node
933                )
934            )
935            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
936
937    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
940def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
941    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def build_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
945def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
946    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
949def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
950    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
953def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
954    a = self.sql(expression.left)
955    b = self.sql(expression.right)
956    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
959def is_parse_json(expression: exp.Expression) -> bool:
960    return isinstance(expression, exp.ParseJSON) or (
961        isinstance(expression, exp.Cast) and expression.is_type("json")
962    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
965def isnull_to_is_null(args: t.List) -> exp.Expression:
966    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
969def generatedasidentitycolumnconstraint_sql(
970    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
971) -> str:
972    start = self.sql(expression, "start") or "1"
973    increment = self.sql(expression, "increment") or "1"
974    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
977def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
978    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
979        if expression.args.get("count"):
980            self.unsupported(f"Only two arguments are supported in function {name}.")
981
982        return self.func(name, expression.this, expression.expression)
983
984    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
987def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
988    this = expression.this.copy()
989
990    return_type = expression.return_type
991    if return_type.is_type(exp.DataType.Type.DATE):
992        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
993        # can truncate timestamp strings, because some dialects can't cast them to DATE
994        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
995
996    expression.this.replace(exp.cast(this, return_type))
997    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
1000def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1001    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1002        if cast and isinstance(expression, exp.TsOrDsAdd):
1003            expression = ts_or_ds_add_cast(expression)
1004
1005        return self.func(
1006            name,
1007            unit_to_var(expression),
1008            expression.expression,
1009            expression.this,
1010        )
1011
1012    return _delta_sql
def unit_to_str( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1015def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1016    unit = expression.args.get("unit")
1017
1018    if isinstance(unit, exp.Placeholder):
1019        return unit
1020    if unit:
1021        return exp.Literal.string(unit.name)
1022    return exp.Literal.string(default) if default else None
def unit_to_var( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1025def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1026    unit = expression.args.get("unit")
1027
1028    if isinstance(unit, (exp.Var, exp.Placeholder)):
1029        return unit
1030    return exp.Var(this=default) if default else None
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
1033def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1034    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1035    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1036    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1037
1038    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
1041def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1042    """Remove table references from columns in WHEN clauses."""
1043    alias = expression.this.args.get("alias")
1044
1045    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1046        return self.dialect.normalize_identifier(identifier).name if identifier else None
1047
1048    targets = {normalize(expression.this.this)}
1049
1050    if alias:
1051        targets.add(normalize(alias.this))
1052
1053    for when in expression.expressions:
1054        when.transform(
1055            lambda node: (
1056                exp.column(node.this)
1057                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1058                else node
1059            ),
1060            copy=False,
1061        )
1062
1063    return self.merge_sql(expression)

Remove table references from columns in WHEN clauses.

def build_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False) -> Callable[[List], ~F]:
1066def build_json_extract_path(
1067    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1068) -> t.Callable[[t.List], F]:
1069    def _builder(args: t.List) -> F:
1070        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1071        for arg in args[1:]:
1072            if not isinstance(arg, exp.Literal):
1073                # We use the fallback parser because we can't really transpile non-literals safely
1074                return expr_type.from_arg_list(args)
1075
1076            text = arg.name
1077            if is_int(text):
1078                index = int(text)
1079                segments.append(
1080                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1081                )
1082            else:
1083                segments.append(exp.JSONPathKey(this=text))
1084
1085        # This is done to avoid failing in the expression validator due to the arg count
1086        del args[2:]
1087        return expr_type(
1088            this=seq_get(args, 0),
1089            expression=exp.JSONPath(expressions=segments),
1090            only_json_types=arrow_req_json_type,
1091        )
1092
1093    return _builder
def json_extract_segments( name: str, quoted_index: bool = True, op: Optional[str] = None) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1096def json_extract_segments(
1097    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1098) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1099    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1100        path = expression.expression
1101        if not isinstance(path, exp.JSONPath):
1102            return rename_func(name)(self, expression)
1103
1104        segments = []
1105        for segment in path.expressions:
1106            path = self.sql(segment)
1107            if path:
1108                if isinstance(segment, exp.JSONPathPart) and (
1109                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1110                ):
1111                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1112
1113                segments.append(path)
1114
1115        if op:
1116            return f" {op} ".join([self.sql(expression.this), *segments])
1117        return self.func(name, expression.this, *segments)
1118
1119    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1122def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1123    if isinstance(expression.this, exp.JSONPathWildcard):
1124        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1125
1126    return expression.name
def filter_array_using_unnest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ArrayFilter) -> str:
1129def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1130    cond = expression.expression
1131    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1132        alias = cond.expressions[0]
1133        cond = cond.this
1134    elif isinstance(cond, exp.Predicate):
1135        alias = "_u"
1136    else:
1137        self.unsupported("Unsupported filter condition")
1138        return ""
1139
1140    unnest = exp.Unnest(expressions=[expression.this])
1141    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1142    return self.sql(exp.Array(expressions=[filtered]))
def to_number_with_nls_param(self, expression: sqlglot.expressions.ToNumber) -> str:
1145def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
1146    return self.func(
1147        "TO_NUMBER",
1148        expression.this,
1149        expression.args.get("format"),
1150        expression.args.get("nlsparam"),
1151    )