Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import logging
   5import functools
   6import itertools
   7import typing as t
   8from collections import deque, defaultdict
   9from functools import reduce
  10
  11import sqlglot
  12from sqlglot import Dialect, exp
  13from sqlglot.helper import first, merge_ranges, while_changing
  14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  15
  16if t.TYPE_CHECKING:
  17    from sqlglot.dialects.dialect import DialectType
  18
  19    DateTruncBinaryTransform = t.Callable[
  20        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
  21    ]
  22
# Module logger, registered under the "sqlglot" logger name
logger = logging.getLogger("sqlglot")

# Final means that an expression should not be simplified
# (stored as a key in an expression's `meta` dict; `simplify` returns such nodes untouched)
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers
# (used by `_simplify_integer_cast` to decide whether an integer cast can be dropped)
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
  33
  34
  35class UnsupportedUnit(Exception):
  36    pass
  37
  38
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect resolved through `Dialect.get_or_raise` and handed to the
            dialect-aware rules (`simplify_coalesce`, `simplify_datetrunc`)
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped.
            A falsy value (None or 0) disables the check.
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Only measure the depth at the top of a connector chain (its parent is not a
        # connector), so each chain is measured once rather than at every nested level.
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        # Nodes flagged as FINAL (see the GROUP BY handling below) are left untouched
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for HAVING: freeze it if it references a group by key
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; nested calls run with root=False
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node, dialect)
        # Re-attach the parent pointer: the node may have been rebuilt above, and later
        # rules (e.g. `simplify_parens`) inspect `node.parent`
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        # Only the root call splices its result into the tree itself; nested calls
        # hand their result back to `exp.replace_children`
        if root:
            expression.replace(node)
        return node

    # `while_changing` comes from sqlglot.helper; presumably it re-applies `_simplify`
    # until the expression stops changing
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 136
 137
 138def connector_depth(expression: exp.Expression) -> int:
 139    """
 140    Determine the maximum depth of a tree of Connectors.
 141
 142    For example:
 143        >>> from sqlglot import parse_one
 144        >>> connector_depth(parse_one("a AND b AND c AND d"))
 145        3
 146    """
 147    stack = deque([(expression, 0)])
 148    max_depth = 0
 149
 150    while stack:
 151        expression, depth = stack.pop()
 152
 153        if not isinstance(expression, exp.Connector):
 154            continue
 155
 156        depth += 1
 157        max_depth = max(depth, max_depth)
 158
 159        stack.append((expression.left, depth))
 160        stack.append((expression.right, depth))
 161
 162    return max_depth
 163
 164
 165def catch(*exceptions):
 166    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 167
 168    def decorator(func):
 169        def wrapped(expression, *args, **kwargs):
 170            try:
 171                return func(expression, *args, **kwargs)
 172            except exceptions:
 173                return expression
 174
 175        return wrapped
 176
 177    return decorator
 178
 179
 180def rewrite_between(expression: exp.Expression) -> exp.Expression:
 181    """Rewrite x between y and z to x >= y AND x <= z.
 182
 183    This is done because comparison simplification is only done on lt/lte/gt/gte.
 184    """
 185    if isinstance(expression, exp.Between):
 186        negate = isinstance(expression.parent, exp.Not)
 187
 188        expression = exp.and_(
 189            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 190            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 191            copy=False,
 192        )
 193
 194        if negate:
 195            expression = exp.paren(expression, copy=False)
 196
 197    return expression
 198
 199
# Maps a comparison class to the comparison that is its logical negation.
# `simplify_not` uses it to rewrite NOT (a < b) as a >= b, and so on.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}

# When a negated comparison has an ALL/ANY subquery predicate on its right-hand side,
# `simplify_not` swaps the predicate for its counterpart from this table.
COMPLEMENT_SUBQUERY_PREDICATES = {
    exp.All: exp.Any,
    exp.Any: exp.All,
}
 213
 214
 215def simplify_not(expression):
 216    """
 217    Demorgan's Law
 218    NOT (x OR y) -> NOT x AND NOT y
 219    NOT (x AND y) -> NOT x OR NOT y
 220    """
 221    if isinstance(expression, exp.Not):
 222        this = expression.this
 223        if is_null(this):
 224            return exp.null()
 225        if this.__class__ in COMPLEMENT_COMPARISONS:
 226            right = this.expression
 227            complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
 228            if complement_subquery_predicate:
 229                right = complement_subquery_predicate(this=right.this)
 230
 231            return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
 232        if isinstance(this, exp.Paren):
 233            condition = this.unnest()
 234            if isinstance(condition, exp.And):
 235                return exp.paren(
 236                    exp.or_(
 237                        exp.not_(condition.left, copy=False),
 238                        exp.not_(condition.right, copy=False),
 239                        copy=False,
 240                    )
 241                )
 242            if isinstance(condition, exp.Or):
 243                return exp.paren(
 244                    exp.and_(
 245                        exp.not_(condition.left, copy=False),
 246                        exp.not_(condition.right, copy=False),
 247                        copy=False,
 248                    )
 249                )
 250            if is_null(condition):
 251                return exp.null()
 252        if always_true(this):
 253            return exp.false()
 254        if is_false(this):
 255            return exp.true()
 256        if isinstance(this, exp.Not):
 257            # double negation
 258            # NOT NOT x -> x
 259            return this.this
 260    return expression
 261
 262
 263def flatten(expression):
 264    """
 265    A AND (B AND C) -> A AND B AND C
 266    A OR (B OR C) -> A OR B OR C
 267    """
 268    if isinstance(expression, exp.Connector):
 269        for node in expression.args.values():
 270            child = node.unnest()
 271            if isinstance(child, expression.__class__):
 272                node.replace(child)
 273    return expression
 274
 275
def simplify_connectors(expression, root=True):
    """
    Fold AND / OR / XOR connectors whose operands are statically known.

    Args:
        expression: node to inspect; only `exp.Connector` nodes are processed.
        root: forwarded to `_flat_simplify`.

    Returns:
        The result of `_flat_simplify` for connectors, otherwise `expression` unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # Pairwise rule handed to `_flat_simplify`. A None result (explicit from
        # `_simplify_comparison` or implicit by falling through) means no rule applied;
        # presumably `_flat_simplify` treats that as "leave this pair alone".
        if isinstance(expression, exp.And):
            # FALSE (or a zero literal) dominates an AND, even over NULL,
            # so these checks come before the NULL check
            if is_false(left) or is_false(right):
                return exp.false()
            if is_zero(left) or is_zero(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            # A always-true operand is redundant: keep the other side
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            # TRUE dominates an OR
            if always_true(left) or always_true(right):
                return exp.true()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and always_false(right))
                or (always_false(left) and is_null(right))
            ):
                return exp.null()
            # A FALSE operand is redundant: keep the other side
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)
        elif isinstance(expression, exp.Xor):
            # x XOR x -> FALSE
            if left == right:
                return exp.false()

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
 313
 314
# Upper-bound and lower-bound comparison classes, grouped for isinstance checks
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Comparison classes recognised by the comparison-based rules
# (`_simplify_comparison`, `simplify_equality`, `simplify_coalesce`, `_simplify_binary`)
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Maps an inequality to the one obtained by swapping its operands (a < b <=> b > a).
# Not referenced in the visible part of this file.
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Expression classes whose presence disqualifies an operand from being treated
# as a shared "column" in `_simplify_comparison`
NONDETERMINISTIC = (exp.Rand, exp.Randn)
# Connector classes handled by `remove_complements` and `absorb_and_eliminate`
AND_OR = (exp.And, exp.Or)
 335
 336
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons that share a non-constant, deterministic operand.

    Args:
        expression: the AND/OR connector the two operands belong to.
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the connector is an OR, False for an AND.

    Returns:
        One of the two operands, a FALSE literal, `expression` itself (when an operand
        has no non-shared side to compare), or None when nothing can be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands appearing in both comparisons; only non-constant, deterministic
        # ones count as the shared "column"
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            try:
                # The remaining operand of each comparison is the value compared against
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # Reduce both values to comparable Python objects: numbers, strings or dates
            if l.is_number and r.is_number:
                l = l.to_py()
                r = r.to_py()
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None
                # python won't compare date and datetime, but many engines will upcast
                l, r = cast_as_datetime(l), cast_as_datetime(r)

            # Try the pair in both orders so every rule below only needs one orientation.
            # NOTE(review): the rules do not check which side of each comparison holds the
            # shared operand; presumably it is expected on the left — confirm against
            # `sort_comparison`.
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Two upper bounds: AND keeps the smaller constant, OR the larger one.
                # NOTE(review): on equal constants strictness (< vs <=) is not compared.
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                # Two lower bounds: AND keeps the larger constant, OR the smaller one
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Disjoint ranges under AND can never both hold
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # An equality combined with another comparison: either the
                        # equality already implies it (keep `a`) or they contradict (FALSE)
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 397
 398
 399def remove_complements(expression, root=True):
 400    """
 401    Removing complements.
 402
 403    A AND NOT A -> FALSE
 404    A OR NOT A -> TRUE
 405    """
 406    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
 407        ops = set(expression.flatten())
 408        for op in ops:
 409            if isinstance(op, exp.Not) and op.this in ops:
 410                return exp.false() if isinstance(expression, exp.And) else exp.true()
 411
 412    return expression
 413
 414
 415def uniq_sort(expression, root=True):
 416    """
 417    Uniq and sort a connector.
 418
 419    C AND A AND B AND B -> A AND B AND C
 420    """
 421    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 422        flattened = tuple(expression.flatten())
 423
 424        if isinstance(expression, exp.Xor):
 425            result_func = exp.xor
 426            # Do not deduplicate XOR as A XOR A != A if A == True
 427            deduped = None
 428            arr = tuple((gen(e), e) for e in flattened)
 429        else:
 430            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 431            deduped = {gen(e): e for e in flattened}
 432            arr = tuple(deduped.items())
 433
 434        # check if the operands are already sorted, if not sort them
 435        # A AND C AND B -> A AND B AND C
 436        for i, (sql, e) in enumerate(arr[1:]):
 437            if sql < arr[i][0]:
 438                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 439                break
 440        else:
 441            # we didn't have to sort but maybe we need to dedup
 442            if deduped and len(deduped) < len(flattened):
 443                expression = result_func(*deduped.values(), copy=False)
 444
 445    return expression
 446
 447
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Args:
        expression: node to inspect; only AND/OR nodes are processed.
        root: when falsy, the node is only processed if its `same_parent` attribute is falsy.

    Returns:
        `expression` itself. The rewrites are applied in place to its operands,
        which are replaced by boolean literals or by simpler operands.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of the nested operands to look at
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # Register (X, NOT Y) operands under the key {X, Y} so that the matching
            # (X, Y) operand can find them in the elimination step
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A negated sub-operand whose positive form is a sibling is replaced by the
            # neutral element of `kind` (TRUE for AND, FALSE for OR)
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # An operand that is a strict superset of another operand is replaced by the
            # neutral element of the outer connector (FALSE for OR, TRUE for AND)
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
 518
 519
 520def propagate_constants(expression, root=True):
 521    """
 522    Propagate constants for conjunctions in DNF:
 523
 524    SELECT * FROM t WHERE a = b AND b = 5 becomes
 525    SELECT * FROM t WHERE a = 5 AND b = 5
 526
 527    Reference: https://www.sqlite.org/optoverview.html
 528    """
 529
 530    if (
 531        isinstance(expression, exp.And)
 532        and (root or not expression.same_parent)
 533        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 534    ):
 535        constant_mapping = {}
 536        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 537            if isinstance(expr, exp.EQ):
 538                l, r = expr.left, expr.right
 539
 540                # TODO: create a helper that can be used to detect nested literal expressions such
 541                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 542                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 543                    constant_mapping[l] = (id(l), r)
 544
 545        if constant_mapping:
 546            for column in find_all_in_scope(expression, exp.Column):
 547                parent = column.parent
 548                column_id, constant = constant_mapping.get(column) or (None, None)
 549                if (
 550                    column_id is not None
 551                    and id(column) != column_id
 552                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 553                ):
 554                    column.replace(constant.copy())
 555
 556    return expression
 557
 558
# Date/datetime interval operations mapped to the arithmetic operator that undoes them.
# Also used by `simplify_literals` to recognise interval operations.
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# All operations `simplify_equality` can move across a comparison, mapped to their inverse
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 571
 572
def _is_number(expression: exp.Expression) -> bool:
    """Return the expression's `is_number` flag; used as an operand predicate in `simplify_equality`."""
    return expression.is_number
 575
 576
def _is_interval(expression: exp.Expression) -> bool:
    """Return True for an `exp.Interval` for which `extract_interval` yields a value (not None)."""
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 579
 580
 581@catch(ModuleNotFoundError, UnsupportedUnit)
 582def simplify_equality(expression: exp.Expression) -> exp.Expression:
 583    """
 584    Use the subtraction and addition properties of equality to simplify expressions:
 585
 586        x + 1 = 3 becomes x = 2
 587
 588    There are two binary operations in the above expression: + and =
 589    Here's how we reference all the operands in the code below:
 590
 591          l     r
 592        x + 1 = 3
 593        a   b
 594    """
 595    if isinstance(expression, COMPARISONS):
 596        l, r = expression.left, expression.right
 597
 598        if l.__class__ not in INVERSE_OPS:
 599            return expression
 600
 601        if r.is_number:
 602            a_predicate = _is_number
 603            b_predicate = _is_number
 604        elif _is_date_literal(r):
 605            a_predicate = _is_date_literal
 606            b_predicate = _is_interval
 607        else:
 608            return expression
 609
 610        if l.__class__ in INVERSE_DATE_OPS:
 611            l = t.cast(exp.IntervalOp, l)
 612            a = l.this
 613            b = l.interval()
 614        else:
 615            l = t.cast(exp.Binary, l)
 616            a, b = l.left, l.right
 617
 618        if not a_predicate(a) and b_predicate(b):
 619            pass
 620        elif not a_predicate(b) and b_predicate(a):
 621            a, b = b, a
 622        else:
 623            return expression
 624
 625        return expression.__class__(
 626            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 627        )
 628    return expression
 629
 630
 631def simplify_literals(expression, root=True):
 632    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 633        return _flat_simplify(expression, _simplify_binary, root)
 634
 635    if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
 636        return expression.this.this
 637
 638    if type(expression) in INVERSE_DATE_OPS:
 639        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 640
 641    return expression
 642
 643
# Binary expression classes that `_simplify_binary` never folds (it returns None for
# them before the NULL-operand shortcut is considered)
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 645
 646
 647def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
 648    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
 649        this = _simplify_integer_cast(expr.this)
 650    else:
 651        this = expr.this
 652
 653    if isinstance(expr, exp.Cast) and this.is_int:
 654        num = this.to_py()
 655
 656        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
 657        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
 658        # engine-dependent
 659        if (
 660            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
 661        ) or (
 662            UTINYINT_MIN <= num <= UTINYINT_MAX
 663            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
 664        ):
 665            return this
 666
 667    return expr
 668
 669
 670def _simplify_binary(expression, a, b):
 671    if isinstance(expression, COMPARISONS):
 672        a = _simplify_integer_cast(a)
 673        b = _simplify_integer_cast(b)
 674
 675    if isinstance(expression, exp.Is):
 676        if isinstance(b, exp.Not):
 677            c = b.this
 678            not_ = True
 679        else:
 680            c = b
 681            not_ = False
 682
 683        if is_null(c):
 684            if isinstance(a, exp.Literal):
 685                return exp.true() if not_ else exp.false()
 686            if is_null(a):
 687                return exp.false() if not_ else exp.true()
 688    elif isinstance(expression, NULL_OK):
 689        return None
 690    elif is_null(a) or is_null(b):
 691        return exp.null()
 692
 693    if a.is_number and b.is_number:
 694        num_a = a.to_py()
 695        num_b = b.to_py()
 696
 697        if isinstance(expression, exp.Add):
 698            return exp.Literal.number(num_a + num_b)
 699        if isinstance(expression, exp.Mul):
 700            return exp.Literal.number(num_a * num_b)
 701
 702        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 703        if isinstance(expression, exp.Sub):
 704            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 705        if isinstance(expression, exp.Div):
 706            # engines have differing int div behavior so intdiv is not safe
 707            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 708                return None
 709            return exp.Literal.number(num_a / num_b)
 710
 711        boolean = eval_boolean(expression, num_a, num_b)
 712
 713        if boolean:
 714            return boolean
 715    elif a.is_string and b.is_string:
 716        boolean = eval_boolean(expression, a.this, b.this)
 717
 718        if boolean:
 719            return boolean
 720    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 721        date, b = extract_date(a), extract_interval(b)
 722        if date and b:
 723            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 724                return date_literal(date + b, extract_type(a))
 725            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 726                return date_literal(date - b, extract_type(a))
 727    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 728        a, date = extract_interval(a), extract_date(b)
 729        # you cannot subtract a date from an interval
 730        if a and b and isinstance(expression, exp.Add):
 731            return date_literal(a + date, extract_type(b))
 732    elif _is_date_literal(a) and _is_date_literal(b):
 733        if isinstance(expression, exp.Predicate):
 734            a, b = extract_date(a), extract_date(b)
 735            boolean = eval_boolean(expression, a, b)
 736            if boolean:
 737                return boolean
 738
 739    return None
 740
 741
def simplify_parens(expression):
    """
    Remove a Paren wrapper when it is judged redundant.

    Returns the wrapped expression when the condition below holds, otherwise
    `expression` unchanged (non-Paren nodes are always returned unchanged).
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        # Never unwrap a parenthesized SELECT
        not isinstance(this, exp.Select)
        # Keep the parentheses directly under a subquery predicate or a bracket
        and not isinstance(parent, (exp.SubqueryPredicate, exp.Bracket))
        # Beyond that, at least one of the following must hold
        and (
            # the parent is neither a Condition nor a Binary
            not isinstance(parent, (exp.Condition, exp.Binary))
            # nested parentheses
            or isinstance(parent, exp.Paren)
            # the content is not a Binary, and not a NOT / IS under a predicate
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            # a predicate whose parent is not itself a predicate
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # (a + b) + c, (a * b) * c, (a * b) + c and (a * b) - c
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
 768
 769
def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Return True for an instance of `exp.NONNULL_CONSTANTS` or anything `_is_date_literal` accepts."""
    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 772
 773
def _is_constant(expression: exp.Expression) -> bool:
    """Return True for an instance of `exp.CONSTANTS` or anything `_is_date_literal` accepts."""
    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 776
 777
def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
    """
    Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    Args:
        expression: node to inspect, either a COALESCE or a comparison involving one.
        dialect: active dialect; the comparison rewrite is skipped for "redshift".

    Returns:
        The first argument of a trivial COALESCE, a parenthesized OR of two AND
        branches for a rewritten comparison, or `expression` unchanged.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
    # because they are not always equivalent. For example,  if `x` is `NULL` and it comes
    # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
    if dialect == "redshift":
        return expression

    if not isinstance(expression, COMPARISONS):
        return expression

    # Locate the COALESCE operand, whichever side of the comparison it is on
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant and everything after it. This mutates the COALESCE inside
    # `expression`, so the `expression.copy()` below already sees the shortened call.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    # Shape: (<rest> IS NOT NULL AND <comparison on rest>) OR (<rest> IS NULL AND <constant> <op> <other>)
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
 840
 841
# Concatenation expression classes handled by `simplify_concat`
CONCATS = (exp.Concat, exp.DPipe)
 843
 844
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Handles the classes in CONCATS; `exp.ConcatWs` gets separator-aware treatment,
    so it is presumably a subclass of one of them. Returns a single string literal
    when everything folds into one, a rebuilt concatenation otherwise, or
    `expression` unchanged when it is not a concatenation or its CONCAT_WS separator
    is not a string literal.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        # The first argument of CONCAT_WS is the separator, the rest are the values
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        # Carry over the flags of the original CONCAT call
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # Group consecutive operands by whether they are string literals. When there is no
    # `expressions` list, the operands come from `flatten()` instead.
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            # Merge a run of adjacent string literals into one literal
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    # Everything collapsed into a single string literal
    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        # Rebuild a left-nested chain of || operators
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)
 885
 886
 887def simplify_conditionals(expression):
 888    """Simplifies expressions like IF, CASE if their condition is statically known."""
 889    if isinstance(expression, exp.Case):
 890        this = expression.this
 891        for case in expression.args["ifs"]:
 892            cond = case.this
 893            if this:
 894                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 895                cond = cond.replace(this.pop().eq(cond))
 896
 897            if always_true(cond):
 898                return case.args["true"]
 899
 900            if always_false(cond):
 901                case.pop()
 902                if not expression.args["ifs"]:
 903                    return expression.args.get("default") or exp.null()
 904    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 905        if always_true(expression.this):
 906            return expression.args["true"]
 907        if always_false(expression.this):
 908            return expression.args.get("false") or exp.null()
 909
 910    return expression
 911
 912
 913def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 914    """
 915    Reduces a prefix check to either TRUE or FALSE if both the string and the
 916    prefix are statically known.
 917
 918    Example:
 919        >>> from sqlglot import parse_one
 920        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 921        'TRUE'
 922    """
 923    if (
 924        isinstance(expression, exp.StartsWith)
 925        and expression.this.is_string
 926        and expression.expression.is_string
 927    ):
 928        return exp.convert(expression.name.startswith(expression.expression.name))
 929
 930    return expression
 931
 932
# A pair of dates; `_datetrunc_range` returns it as the half-open interval [min, max).
DateRange = t.Tuple[datetime.date, datetime.date]
 934
 935
 936def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 937    """
 938    Get the date range for a DATE_TRUNC equality comparison:
 939
 940    Example:
 941        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 942    Returns:
 943        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 944    """
 945    floor = date_floor(date, unit, dialect)
 946
 947    if date != floor:
 948        # This will always be False, except for NULL values.
 949        return None
 950
 951    return floor, floor + interval(unit)
 952
 953
 954def _datetrunc_eq_expression(
 955    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
 956) -> exp.Expression:
 957    """Get the logical expression for a date range"""
 958    return exp.and_(
 959        left >= date_literal(drange[0], target_type),
 960        left < date_literal(drange[1], target_type),
 961        copy=False,
 962    )
 963
 964
 965def _datetrunc_eq(
 966    left: exp.Expression,
 967    date: datetime.date,
 968    unit: str,
 969    dialect: Dialect,
 970    target_type: t.Optional[exp.DataType],
 971) -> t.Optional[exp.Expression]:
 972    drange = _datetrunc_range(date, unit, dialect)
 973    if not drange:
 974        return None
 975
 976    return _datetrunc_eq_expression(left, drange, target_type)
 977
 978
 979def _datetrunc_neq(
 980    left: exp.Expression,
 981    date: datetime.date,
 982    unit: str,
 983    dialect: Dialect,
 984    target_type: t.Optional[exp.DataType],
 985) -> t.Optional[exp.Expression]:
 986    drange = _datetrunc_range(date, unit, dialect)
 987    if not drange:
 988        return None
 989
 990    return exp.and_(
 991        left < date_literal(drange[0], target_type),
 992        left >= date_literal(drange[1], target_type),
 993        copy=False,
 994    )
 995
 996
# Rewrites for `DATE_TRUNC(unit, x) <op> <date literal>`, keyed by the comparison class.
# `simplify_datetrunc` calls each entry with (l, dt, u, d, t): the expression being
# truncated, the date extracted from the literal, the lowercased unit, the dialect and the
# type extracted from the literal. Each entry returns a predicate on `l` itself; the EQ/NEQ
# helpers may return None, in which case the caller keeps the original expression.
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # trunc(x) < dt: compare against dt itself when dt is aligned to the unit, otherwise
    # against the start of the unit that follows dt
    exp.LT: lambda l, dt, u, d, t: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
    # trunc(x) > dt: x must be at or after the start of the unit that follows floor(dt)
    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
    # trunc(x) <= dt: x must be before the start of the unit that follows floor(dt)
    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
    # trunc(x) >= dt: x must be at or after dt rounded up to a unit boundary
    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
# Every comparison class `simplify_datetrunc` looks at: the binary ones above plus IN.
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
# Truncation expressions that `simplify_datetrunc` and `_is_datetrunc_predicate` recognise.
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
1008
1009
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return True when `left` is a date truncation and `right` is a date literal."""
    if not isinstance(left, DATETRUNCS):
        return False

    return _is_date_literal(right)
1012
1013
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Three shapes are handled: a truncation applied directly to a date literal (folded into
    a literal), a binary comparison between a truncation and a date literal, and an IN
    list of date literals. Anything else is returned unchanged, and the `catch` decorator
    returns the input unchanged if `ModuleNotFoundError` or `UnsupportedUnit` is raised.
    """
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        operand = expression.this
        trunc_type = extract_type(operand)
        date = extract_date(operand)
        if date and expression.unit:
            floored = date_floor(date, expression.unit.name.lower(), dialect)
            return date_literal(floored, trunc_type)
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        lhs, rhs = expression.left, expression.right

        if not _is_datetrunc_predicate(lhs, rhs):
            return expression

        trunc = t.cast(exp.DateTrunc, lhs)
        trunc_arg = trunc.this
        unit = trunc.unit.name.lower()
        date = extract_date(rhs)

        if not date:
            return expression

        transform = DATETRUNC_BINARY_COMPARISONS[comparison]
        rewritten = transform(trunc_arg, date, unit, dialect, extract_type(rhs))
        # A transform may decline by returning None; keep the original in that case
        return rewritten or expression

    if isinstance(expression, exp.In):
        lhs = expression.this
        candidates = expression.expressions

        if candidates and all(_is_datetrunc_predicate(lhs, c) for c in candidates):
            trunc = t.cast(exp.DateTrunc, lhs)
            unit = trunc.unit.name.lower()

            ranges = []
            for candidate in candidates:
                date = extract_date(candidate)
                if not date:
                    return expression

                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)
            target_type = extract_type(*candidates)

            range_checks = [
                _datetrunc_eq_expression(trunc, drange, target_type) for drange in ranges
            ]
            return exp.or_(*range_checks, copy=False)

    return expression
1077
1078
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right; when neither rule decides,
    the operands are ordered by their `gen` strings. Swapping operands also replaces the
    operator with its entry in `INVERSE_COMPARISONS`, when it has one.
    """
    comparison = expression.__class__
    if comparison not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_ordered = (
        (lhs_is_column and not rhs_is_column)
        or (rhs_is_const and not lhs_is_const)
        or isinstance(rhs, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    needs_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if not needs_swap:
        return expression

    swapped_type = INVERSE_COMPARISONS.get(comparison, comparison)
    return swapped_type(this=rhs, expression=lhs)
1098
1099
# (side, kind) pairs of the joins that `remove_where_true` may turn into CROSS joins
# when their ON condition is always true.
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
1109
1110
def remove_where_true(expression):
    """Drop WHERE clauses that are always true and turn trivially-true joins into CROSS joins.

    The expression is modified in place. A join is only converted when its ON condition
    is always true, it has no USING clause or join method, and its (side, kind) is in `JOINS`.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
1125
1126
def always_true(expression):
    """Return a truthy value for a TRUE boolean or for any literal that is not zero."""
    if isinstance(expression, exp.Boolean):
        flag = expression.this
        if flag:
            return flag

    return isinstance(expression, exp.Literal) and not is_zero(expression)
1131
1132
def always_false(expression):
    """Return True for a FALSE boolean, a NULL, or a literal equal to zero."""
    if is_false(expression):
        return True
    if is_null(expression):
        return True
    return is_zero(expression)
1135
1136
def is_zero(expression):
    """Return True when `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False

    return expression.to_py() == 0
1139
1140
def is_complement(a, b):
    """Return whether `b` is `NOT a`, i.e. a NOT node whose operand equals `a`."""
    if not isinstance(b, exp.Not):
        return False

    return b.this == a
1143
1144
def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Boolean node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False

    return not a.this
1147
1148
def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
1151
1152
def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a TRUE/FALSE boolean literal for EQ, IS, NEQ, GT, GTE, LT and LTE nodes,
    and None for any other kind of expression.
    """
    # Checked in order; each comparison is deferred until its node type has matched
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
1167
1168
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`, or return None when that is not possible.

    Args:
        value: a `datetime.datetime` (its date part is used), a `datetime.date`
            (returned as is), or an ISO 8601 string.
    Returns:
        The date, or None if `value` is neither a date object nor a parseable ISO string.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: not an ISO 8601 string; TypeError: not a string at all
        return None
1178
1179
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`, or return None when that is not possible.

    Args:
        value: a `datetime.datetime` (returned as is), a `datetime.date` (converted to
            midnight of that day), or an ISO 8601 string.
    Returns:
        The datetime, or None if `value` is neither a date object nor a parseable ISO string.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: not an ISO 8601 string; TypeError: not a string at all
        return None
1189
1190
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python date type that corresponds to the SQL type `to`.

    Args:
        value: the value to convert; any falsy value yields None.
        to: the target data type.
    Returns:
        A `datetime.date` when `to` is DATE, a `datetime.datetime` when `to` is one of the
        other temporal types, and None when `to` is not temporal or the conversion fails.
    """
    if not value:
        return None
    # DATE is checked before the general temporal types so it gets the narrower conversion
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1199
1200
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Extract the Python date value from a cast of a literal, e.g. `CAST('2021-01-01' AS DATE)`.

    Handles `Cast` nodes and `TsOrDsToDate` nodes without a format, including nested ones
    (the inner value is extracted first and then converted to the outer type).

    Returns:
        The date or datetime, or None when `cast` is not such an expression or its
        operand cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: resolve the inner one first
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
1216
1217
def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True when `extract_date` can pull a date value out of `expression`."""
    extracted = extract_date(expression)
    return extracted is not None
1220
1221
def extract_interval(expression):
    """Convert an interval expression into the delta built by `interval`.

    Returns None when the unit is unsupported, the amount is not a valid integer, or
    the module `interval` imports is missing.
    """
    try:
        amount = int(expression.this.to_py())
        unit_name = expression.text("unit").lower()
        delta = interval(unit_name, amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None

    return delta
1229
1230
def extract_type(*expressions):
    """Return the first truthy type found among `expressions`.

    For a Cast the target type (`to`) is used; for anything else its `type` attribute.
    When no expression yields a truthy type, the last candidate examined is returned
    (None if there were no expressions at all).
    """
    found = None

    for candidate in expressions:
        if isinstance(candidate, exp.Cast):
            found = candidate.to
        else:
            found = candidate.type

        if found:
            return found

    return found
1239
1240
def date_literal(date, target_type=None):
    """Build `CAST('<date>' AS <type>)` for a Python date or datetime value.

    `target_type` is used when it is a temporal type; otherwise the cast targets
    DATETIME for datetime values and DATE for everything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_to = target_type
    elif isinstance(date, datetime.datetime):
        cast_to = exp.DataType.Type.DATETIME
    else:
        cast_to = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, cast_to)
1250
1251
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` units of `unit`.

    Supported units are year, quarter, month, week, day, hour, minute and second.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
        ModuleNotFoundError: if `dateutil` is not installed.
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, how many of that keyword make up one unit)
    spans = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, per_unit in spans:
        if unit == name:
            return relativedelta(**{keyword: per_unit * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1273
1274
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date to floor. A `datetime.datetime` is also accepted; its time of day
            is reset to midnight because every supported unit is at least a day long.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: supplies `WEEK_OFFSET`, the first day of the week relative to Monday
            as the code uses it (0 floors to Monday, -1 to Sunday).
    Returns:
        The floored value, of the same type as `d`.
    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    if isinstance(d, datetime.datetime):
        # Without this a datetime would keep its time of day and not be a real floor
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday through 6 for Sunday. The distance back to the first
        # day of the week must wrap modulo 7; otherwise a date that is already the first
        # day of an offset week floors to the previous week, or moves forward in time.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1296
1297
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary; a value already on a boundary is returned as is."""
    lower = date_floor(d, unit, dialect)
    return d if lower == d else lower + interval(unit)
1305
1306
def boolean_literal(condition):
    """Return a TRUE expression when `condition` is truthy and a FALSE expression otherwise."""
    if condition:
        return exp.true()

    return exp.false()
1309
1310
1311def _flat_simplify(expression, simplifier, root=True):
1312    if root or not expression.same_parent:
1313        operands = []
1314        queue = deque(expression.flatten(unnest=False))
1315        size = len(queue)
1316
1317        while queue:
1318            a = queue.popleft()
1319
1320            for b in queue:
1321                result = simplifier(expression, a, b)
1322
1323                if result and result is not expression:
1324                    queue.remove(b)
1325                    queue.appendleft(result)
1326                    break
1327            else:
1328                operands.append(a)
1329
1330        if len(operands) < size:
1331            return functools.reduce(
1332                lambda a, b: expression.__class__(this=a, expression=b), operands
1333            )
1334    return expression
1335
1336
def gen(expression: t.Any, comments: bool = False) -> str:
    """Render `expression` as a cheap pseudo SQL string that can be sorted and compared.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so a bare minimum sql generator is used here instead.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)
1348
1349
class Gen:
    """Bare-minimum, non-recursive pseudo SQL generator behind `gen`.

    Generation is driven by an explicit stack. Popping an expression calls its
    `<key>_sql` handler when one exists, which pushes the pieces of that expression
    (strings, child expressions or lists of them) back onto the stack. Because the stack
    is last-in first-out, handlers push their pieces in reverse of the order in which
    they should appear in the output.
    """

    def __init__(self):
        # Pending work items; the item pushed last is emitted first
        self.stack = []
        # Output fragments, in emission order
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """Return the pseudo SQL for `expression`, optionally including its comments."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    # Pushed before the node's own pieces, so it is emitted after them
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the upper-cased key followed by the set args
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                if pushed:
                    # Drop the separator that would otherwise precede the first item.
                    # Only done when an item was actually pushed: a list holding nothing
                    # but None would otherwise pop an unrelated entry off the stack.
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            # Quoted names keep their case, unquoted ones are upper-cased
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." that would otherwise precede the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Skip the first arg type, which is rendered through the type name below
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." that would otherwise precede the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this op expression`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `op this`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `NAME(args)` using every arg value of the function; None values are skipped."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the set args of `node` as `:name value` pairs.

        Args:
            node: the expression whose args are rendered.
            arg_index: number of leading entries of `node.arg_types` to skip.
        Returns:
            True if at least one arg was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
logger = <Logger sqlglot (WARNING)>
FINAL = 'final'
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised by `interval` and `date_floor` when they are given a unit they do not handle."""

    pass

Common base class for all non-exit exceptions.

def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, max_depth: Optional[int] = None):
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rules are applied by `_simplify` repeatedly through `while_changing`, and
    `remove_where_true` is run once on the final result.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved with `Dialect.get_or_raise` and passed to the dialect-aware
            rules (`simplify_coalesce` and `simplify_datetrunc`)
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Only the outermost connector of a chain is measured against max_depth
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        # Nodes marked as final are left untouched
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Likewise for a HAVING clause that references a group by key
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Simplify the children as non-root nodes
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node, dialect)
        # Give the possibly new node the parent of the original expression
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression: expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
  • max_depth: Chains of Connectors (AND, OR, etc) exceeding max_depth will be skipped
Returns:

sqlglot.Expression: simplified expression

def connector_depth(expression: sqlglot.expressions.Expression) -> int:
def connector_depth(expression: exp.Expression) -> int:
    """
    Return the length of the longest chain of nested Connectors rooted at `expression`.

    An expression that is not a Connector has depth 0.

    For example:
        >>> from sqlglot import parse_one
        >>> connector_depth(parse_one("a AND b AND c AND d"))
        3
    """
    deepest = 0
    pending = deque([(expression, 0)])

    while pending:
        node, level = pending.pop()

        if isinstance(node, exp.Connector):
            level += 1
            deepest = max(level, deepest)

            # Both operands sit one level below this connector
            pending.append((node.left, level))
            pending.append((node.right, level))

    return deepest

Determine the maximum depth of a tree of Connectors.

For example:
>>> from sqlglot import parse_one
>>> connector_depth(parse_one("a AND b AND c AND d"))
3
def catch(*exceptions):
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised

    The decorated function is called as usual; if it raises one of `exceptions`, the
    expression it was given (its first argument) is returned unchanged. Any other
    exception propagates to the caller.
    """

    def decorator(func):
        # Keep the name, docstring and signature metadata of the wrapped simplifier
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    When the BETWEEN is the operand of a NOT, the result is wrapped in parentheses.
    """
    if not isinstance(expression, exp.Between):
        return expression

    negated = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    if negated:
        return exp.paren(rewritten, copy=False)

    return rewritten

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

COMPLEMENT_SUBQUERY_PREDICATES = {<class 'sqlglot.expressions.All'>: <class 'sqlglot.expressions.Any'>, <class 'sqlglot.expressions.Any'>: <class 'sqlglot.expressions.All'>}
def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a NOT expression.

    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT NULL to NULL, flips complementable comparisons, negates
    constant operands and removes double negation.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    operand_type = operand.__class__
    if operand_type in COMPLEMENT_COMPARISONS:
        rhs = operand.expression
        # ALL <-> ANY need to be swapped as well when the comparison is complemented
        flipped_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(rhs.__class__)
        if flipped_predicate:
            rhs = flipped_predicate(this=rhs.this)

        return COMPLEMENT_COMPARISONS[operand_type](this=operand.this, expression=rhs)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: the connector flips and both sides get negated
            dual = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                dual(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression

Demorgan's Law NOT (x OR y) -> NOT x AND NOT y NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    Drop parentheses around a nested connector of the same kind.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression

A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """Fold AND / OR / XOR connectors whose operands are statically known.

    Args:
        expression: the node to simplify; anything that is not a Connector is returned as is.
        root: forwarded to `_flat_simplify`.

    Returns:
        The result of `_flat_simplify` for connectors, otherwise `expression` unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # Pairwise rule applied by `_flat_simplify`. Falling off the end (returning
        # None) happens for XOR with distinct operands and for any other connector type.
        if isinstance(expression, exp.And):
            # The checks below are ordered: FALSE / 0 win over NULL, which wins over TRUE
            if is_false(left) or is_false(right):
                return exp.false()
            if is_zero(left) or is_zero(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            # NULL only survives an OR when the other side is NULL or always false
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and always_false(right))
                or (always_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)
        elif isinstance(expression, exp.Xor):
            # A XOR A -> FALSE
            if left == right:
                return exp.false()

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
AND_OR = (<class 'sqlglot.expressions.And'>, <class 'sqlglot.expressions.Or'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR):
        return expression
    if not root and expression.same_parent:
        # Nested link of a larger chain of the same connector; skipped here
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )
    if not has_complement:
        return expression

    if isinstance(expression, exp.And):
        return exp.false()
    return exp.true()

Removing complements.

A AND NOT A -> FALSE A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered (and, for AND / OR, deduplicated) by the string that
    `gen` produces for them. Only applied when `root` is truthy or the node's
    `same_parent` is falsy; other nodes are returned unchanged.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            # Keyed by generated SQL, so a later duplicate overwrites an earlier one
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        # `i` indexes the element just before `sql`, because enumerate starts at 0 over arr[1:]
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if deduped and len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are done in place with `replace` on the operands; the
    (mutated) `expression` itself is what gets returned. Only applied to
    AND / OR nodes when `root` is truthy or `same_parent` is falsy.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the dual connector: the type of the nested operands we look into
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            a, b = op.unnest_operands()
            # Key each pair by its un-negated operands; remember the operand that would remain
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # The negated operand is replaced by the neutral element of `kind`
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            superset = set(op.flatten())
            # `<` is a proper-subset test, so an operand never absorbs itself
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression

absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only `column = literal` equalities (column on the left, literal on the
    right) are collected. Matching columns are replaced in place with a copy
    of the literal, and `expression` is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the column node in the defining equality, literal)
        constant_mapping = {}
        # exp.If subtrees are pruned from the walk
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Skip the defining occurrence itself (same id) and `column IS NULL` checks
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
170        def wrapped(expression, *args, **kwargs):
171            try:
172                return func(expression, *args, **kwargs)
173            except exceptions:
174                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """Fold literal arithmetic, double negation and inverse date operations.

    Non-connector binaries are delegated to `_flat_simplify` together with
    `_simplify_binary`; `root` is forwarded as is.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        negated = expression.this
        if isinstance(negated, exp.Neg):
            # -(-x) -> x
            return negated.this

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """Remove a pair of parentheses when the checks below deem them redundant.

    Returns the wrapped expression when the Paren can be dropped, otherwise
    `expression` unchanged. Non-Paren nodes are returned as is.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        # Never unwrap a parenthesized SELECT or an operand of a subquery predicate / bracket
        not isinstance(this, exp.Select)
        and not isinstance(parent, (exp.SubqueryPredicate, exp.Bracket))
        and (
            # Any single one of the following is enough to drop the parentheses
            not isinstance(parent, (exp.Condition, exp.Binary))
            or isinstance(parent, exp.Paren)
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # (a + b) + c, (a * b) * c, and (a * b) +/- c
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
def simplify_coalesce( expression: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType]) -> sqlglot.expressions.Expression:
def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    Args:
        expression: a COALESCE node or a comparison that has a COALESCE on one side.
        dialect: compared against "redshift" to disable the comparison rewrite.

    Returns:
        The simplified expression, or `expression` unchanged when no rule applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
    # because they are not always equivalent. For example,  if `x` is `NULL` and it comes
    # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
    if dialect == "redshift":
        return expression

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant and everything after it. This mutates the COALESCE inside
    # `expression`, so the `expression.copy()` below already sees the shortened call.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    # Shape: (<coalesce> IS NOT NULL AND <original comparison>)
    #     OR (<coalesce> IS NULL AND <constant arg> <op> <other>)
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Adjacent string literals are merged into a single literal (joined with the
    separator for CONCAT_WS). If everything collapses into one string literal,
    that literal is returned on its own.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        # The first argument is the separator, the rest are the values to join
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        # Carry over the flags of the original call
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # Falls back to `flatten()` when there is no `expressions` list to iterate
    # (presumably the DPipe case -- TODO confirm)
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        # Rebuild a left-nested chain of `||` operators
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches whose condition is always false are removed. A branch
    whose condition is always true replaces the whole CASE only when no earlier
    branch with a statically unknown condition remains, because such a branch
    could still match first at runtime. When every branch has been removed,
    the default (or NULL) is returned.

    Returns:
        The simplified expression, or `expression` (possibly with branches
        removed in place) when it cannot be reduced further.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Set once a branch with a statically unknown condition has been kept
        undecided_branch_seen = False

        # Iterate over a snapshot: `case.pop()` removes entries from the live list,
        # which would otherwise make the loop skip the branch after a removed one.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if not undecided_branch_seen:
                    # Every preceding branch was statically false, so this one always wins
                    return case.args["true"]
                # Otherwise an earlier WHEN may match first; the branch has to stay
            elif always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                undecided_branch_seen = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold STARTSWITH into TRUE or FALSE when both the string and the prefix
    are string literals; anything else is returned untouched.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    prefix = expression.expression
    if expression.this.is_string and prefix.is_string:
        has_prefix = expression.name.startswith(prefix.name)
        return exp.convert(has_prefix)

    return expression

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.Dialect, sqlglot.expressions.DataType], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.NEQ'>}
def simplify_datetrunc(expression, *args, **kwargs):
170        def wrapped(expression, *args, **kwargs):
171            try:
172                return func(expression, *args, **kwargs)
173            except exceptions:
174                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right. When neither rule
    decides, the operands are ordered by their `gen` string. Swapping the
    operands also swaps the operator through INVERSE_COMPARISONS where an
    inverse is registered.
    """
    comparison_type = expression.__class__
    if comparison_type not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    # Already canonical, or not safe to reorder
    if lhs_is_column and not rhs_is_column:
        return expression
    if rhs_is_const and not lhs_is_const:
        return expression
    if isinstance(rhs, exp.SubqueryPredicate):
        return expression

    needs_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if needs_swap:
        swapped_type = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
        return swapped_type(this=rhs, expression=lhs)

    return expression
JOINS = {('RIGHT', ''), ('', 'INNER'), ('RIGHT', 'OUTER'), ('', '')}
def remove_where_true(expression):
def remove_where_true(expression):
    """Remove WHERE clauses and join conditions that are always true.

    Works in place on `expression` and returns nothing.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()
    for join in expression.find_all(exp.Join):
        # Only joins without USING / method whose (side, kind) is listed in JOINS qualify
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            # Drop the ON condition and turn the join into a CROSS join
            join.args["on"].pop()
            join.set("side", None)
            join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """Return a truthy value if `expression` is a true boolean or a non-zero literal."""
    if isinstance(expression, exp.Boolean):
        truth = expression.this
        if truth:
            return truth

    return isinstance(expression, exp.Literal) and not is_zero(expression)
def always_false(expression):
def always_false(expression):
    """Return True if `expression` is FALSE, NULL or a zero literal."""
    checks = (is_false, is_null, is_zero)
    return any(check(expression) for check in checks)
def is_zero(expression):
def is_zero(expression):
    """Return True if `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
def is_complement(a, b):
def is_complement(a, b):
    """Return True if `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a Boolean node (not a subclass) holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a boolean literal for EQ / IS, NEQ, GT, GTE, LT and LTE nodes and
    None for every other node type.
    """
    # Checked in order; the comparison runs only for the matching node type
    evaluators = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )
    for node_types, compare in evaluators:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`.

    Args:
        value: a datetime (its date part is taken), a date (returned as is),
            or an ISO-8601 string.

    Returns:
        The date, or None when `value` cannot be interpreted as one.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: `value` is not a string; ValueError: not a valid ISO format
        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`.

    Args:
        value: a datetime (returned as is), a date (promoted to midnight of
            that day), or an ISO-8601 string.

    Returns:
        The datetime, or None when `value` cannot be interpreted as one.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: `value` is not a string; ValueError: not a valid ISO format
        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Cast `value` to the Python type that matches the SQL type `to`.

    Returns a date for DATE, a datetime for the other temporal types, and
    None when `value` is falsy or `to` is not a temporal type.
    """
    if not value:
        return None
    # DATE is tested first so that it maps to a date rather than a datetime
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Extract a Python date / datetime from a CAST or TS_OR_DS_TO_DATE over a literal.

    Nested casts are resolved recursively. Returns None for any other node,
    for TS_OR_DS_TO_DATE with a format argument, and when the value cannot be
    cast to the target type.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without an explicit format the target type is treated as DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Inner cast: resolve it first, then cast its result to the outer type
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """Convert an interval expression into the delta returned by `interval`.

    Returns None when the unit is unsupported, the amount is not a valid
    integer, or a required module is missing.
    """
    try:
        amount = int(expression.this.to_py())
        unit_name = expression.text("unit").lower()
        delta = interval(unit_name, amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
    return delta
def extract_type(*expressions):
def extract_type(*expressions):
    """Return the first truthy type found among `expressions`.

    For a Cast the target type (`to`) is used, otherwise the node's `type`.
    If none is truthy, the value obtained from the last expression is
    returned (None when called without arguments).
    """
    found = None
    for node in expressions:
        found = node.to if isinstance(node, exp.Cast) else node.type
        if found:
            return found

    return found
def date_literal(date, target_type=None):
def date_literal(date, target_type=None):
    """Build a CAST of the string form of `date` to a temporal type.

    `target_type` is used when it is a temporal type; otherwise DATETIME is
    chosen for datetime values and DATE for everything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_to = target_type
    elif isinstance(date, datetime.datetime):
        cast_to = exp.DataType.Type.DATETIME
    else:
        cast_to = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), cast_to)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` units of `unit`.

    Supported units: year, quarter, month, week, day, hour, minute, second.

    Raises:
        UnsupportedUnit: if `unit` is none of the above.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    spec = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }.get(unit)

    if spec is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, per_unit = spec
    return relativedelta(**{keyword: per_unit * n})
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.Dialect) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Truncate `d` to the start of the `unit` it falls into.

    Args:
        d: the date to truncate.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: provides `WEEK_OFFSET`, which shifts the first day of the week
            relative to Monday (0 keeps Monday as the start).

    Returns:
        The first day of the enclosing year / quarter / month / week, or `d`
        itself for "day".

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter containing d: 1, 4, 7 or 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday .. 6 for Sunday. The modulo keeps the shift within
        # 0..6 days, so a floor never moves forward in time or back by a whole week
        # when WEEK_OFFSET is non-zero.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.Dialect) -> datetime.date:
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary; a date already on a boundary is returned as is."""
    boundary = date_floor(d, unit, dialect)
    return d if boundary == d else boundary + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return the TRUE literal if `condition` is truthy, otherwise the FALSE literal."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: Any, comments: bool = False) -> str:
def gen(expression: t.Any, comments: bool = False) -> str:
    """Produce a cheap pseudo-SQL string that is suitable for sorting and deduplication.

    Optimization needs to sort and dedupe SQL, and running the real generator
    for that is expensive, so a bare minimum generator is used instead.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

Arguments:
  • expression: the expression to convert into a SQL string.
  • comments: whether to include the expression's comments.
class Gen:
    """Bare-minimum, stack-based pseudo SQL generator.

    The output is meant to be a cheap, deterministic string for sorting and
    deduplicating expressions, not a faithful SQL rendering.

    Nodes are processed iteratively with an explicit LIFO stack instead of
    recursion, so every handler pushes its fragments in REVERSE emission order.
    An expression is dispatched to a `<key>_sql` method when one exists, then to
    `_function` for `exp.Func` nodes, and otherwise rendered generically as its
    upper-cased key followed by its `:name,value` arguments.
    """

    def __init__(self):
        # Pending work (expressions, lists, string fragments), consumed last-in first-out.
        self.stack = []
        # Emitted fragments in output order; joined to form the result.
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """Render `expression` into a compact pseudo SQL string.

        Args:
            expression: the root node to render.
            comments: when true, a node's comments are emitted as ` /*c1,c2*/`
                after that node's own fragments.

        Returns:
            The concatenation of all emitted fragments.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                # Pushed before the handler runs, so it is popped (emitted) after
                # everything the handler pushes for this node.
                if comments and node.comments:
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # _args runs first (it pushes the arguments), then the key is
                    # pushed on top so that it is emitted before them.
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the trailing separator only if this list pushed one. A list
                # holding nothing but None pushes no items, and popping then would
                # discard an unrelated pending fragment (e.g. a closing ")").
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        """Emit `this + expression`."""
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        """Emit `this AS alias`, reading both operands from `e.args`."""
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        """Emit `this AND expression`."""
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Emit `NAME(expressions)` for an anonymous function call.

        A plain-string or unquoted-identifier name is upper-cased; a quoted
        identifier keeps its text and is wrapped in double quotes.

        Raises:
            ValueError: if `e.this` is neither a `str` nor an `exp.Identifier`.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        """Emit `this BETWEEN low AND high`."""
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        """Emit `TRUE` or `FALSE` depending on the truthiness of `e.this`."""
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        """Emit `this[expressions]`."""
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        """Emit the column's parts joined by dots."""
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator pushed after the first part.
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        """Emit the type name, a space, then every argument after the first arg type."""
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        """Emit `this / expression`."""
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        """Emit `this.expression`."""
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        """Emit `this = expression`."""
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        """Emit `FROM this`."""
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        """Emit `this > expression`."""
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        """Emit `this >= expression`."""
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        """Emit the identifier text, double-quoted when `e.quoted` is truthy."""
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        """Emit `this ILIKE expression`."""
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        """Emit `this IN (...)`, where `...` is every set argument after the first arg type."""
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        """Emit `this DIV expression`."""
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        """Emit `this IS expression`."""
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        """Emit `this Like expression` (the operator text is intentionally left as-is)."""
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        """Emit the literal text, single-quoted when `e.is_string` is truthy."""
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        """Emit `this < expression`."""
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        """Emit `this <= expression`."""
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        """Emit `this % expression`."""
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        """Emit `this * expression`."""
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        """Emit `-this`."""
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        """Emit `this <> expression`."""
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        """Emit `NOT this`."""
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        """Emit `NULL`."""
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        """Emit `this OR expression`."""
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        """Emit `(this)`."""
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        """Emit `this - expression`."""
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        """Emit `(this)`, then the alias if set, then the arguments after the first two arg types."""
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        """Emit the dotted table parts, then the alias if set, then the arguments after the first four arg types."""
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator pushed after the first part.
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        """Emit ` AS this`, followed by `(columns)` when the alias has columns."""
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        """Emit the variable's text."""
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this`, `op`, `expression` so they are emitted in that order."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `op` followed by `this` so they are emitted in that order."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Emit `SQL_NAME(v1,v2,...)` using every value in `e.args`."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the node's set arguments as `:name,value` pairs.

        Args:
            node: the expression whose arguments are rendered.
            arg_index: number of leading entries of `node.arg_types` to skip
                (callers use it for arguments they render themselves).

        Returns:
            True if at least one argument was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: exp.Expression, comments: bool = False) -> str:
    """Render `expression` into a compact pseudo SQL string.

    Work items are taken from an explicit LIFO stack: expressions are dispatched
    to a `<key>_sql` handler (or `_function` / a generic rendering), lists are
    expanded into comma-separated items, and anything else that is not None is
    stringified and emitted.

    Args:
        expression: the root node to render.
        comments: when true, a node's comments are emitted as ` /*c1,c2*/`
            after that node's own fragments.

    Returns:
        The concatenation of all emitted fragments.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            # Pushed before the handler runs, so it is popped (emitted) after
            # everything the handler pushes for this node.
            if comments and node.comments:
                self.stack.append(f" /*{','.join(node.comments)}*/")

            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                # _args runs first (it pushes the arguments), then the key is
                # pushed on top so that it is emitted before them.
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            # Drop the trailing separator only if this list pushed one. A list
            # holding nothing but None pushes no items, and popping then would
            # discard an unrelated pending fragment (e.g. a closing ")").
            if pushed:
                self.stack.pop()
        else:
            if node is not None:
                self.sqls.append(str(node))

    return "".join(self.sqls)
def add_sql(self, e: sqlglot.expressions.Add) -> None:
1388    def add_sql(self, e: exp.Add) -> None:
1389        self._binary(e, " + ")
def alias_sql(self, e: sqlglot.expressions.Alias) -> None:
1391    def alias_sql(self, e: exp.Alias) -> None:
1392        self.stack.extend(
1393            (
1394                e.args.get("alias"),
1395                " AS ",
1396                e.args.get("this"),
1397            )
1398        )
def and_sql(self, e: sqlglot.expressions.And) -> None:
1400    def and_sql(self, e: exp.And) -> None:
1401        self._binary(e, " AND ")
def anonymous_sql(self, e: sqlglot.expressions.Anonymous) -> None:
1403    def anonymous_sql(self, e: exp.Anonymous) -> None:
1404        this = e.this
1405        if isinstance(this, str):
1406            name = this.upper()
1407        elif isinstance(this, exp.Identifier):
1408            name = this.this
1409            name = f'"{name}"' if this.quoted else name.upper()
1410        else:
1411            raise ValueError(
1412                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
1413            )
1414
1415        self.stack.extend(
1416            (
1417                ")",
1418                e.expressions,
1419                "(",
1420                name,
1421            )
1422        )
def between_sql(self, e: sqlglot.expressions.Between) -> None:
1424    def between_sql(self, e: exp.Between) -> None:
1425        self.stack.extend(
1426            (
1427                e.args.get("high"),
1428                " AND ",
1429                e.args.get("low"),
1430                " BETWEEN ",
1431                e.this,
1432            )
1433        )
def boolean_sql(self, e: sqlglot.expressions.Boolean) -> None:
1435    def boolean_sql(self, e: exp.Boolean) -> None:
1436        self.stack.append("TRUE" if e.this else "FALSE")
def bracket_sql(self, e: sqlglot.expressions.Bracket) -> None:
1438    def bracket_sql(self, e: exp.Bracket) -> None:
1439        self.stack.extend(
1440            (
1441                "]",
1442                e.expressions,
1443                "[",
1444                e.this,
1445            )
1446        )
def column_sql(self, e: sqlglot.expressions.Column) -> None:
1448    def column_sql(self, e: exp.Column) -> None:
1449        for p in reversed(e.parts):
1450            self.stack.extend((p, "."))
1451        self.stack.pop()
def datatype_sql(self, e: sqlglot.expressions.DataType) -> None:
1453    def datatype_sql(self, e: exp.DataType) -> None:
1454        self._args(e, 1)
1455        self.stack.append(f"{e.this.name} ")
def div_sql(self, e: sqlglot.expressions.Div) -> None:
1457    def div_sql(self, e: exp.Div) -> None:
1458        self._binary(e, " / ")
def dot_sql(self, e: sqlglot.expressions.Dot) -> None:
1460    def dot_sql(self, e: exp.Dot) -> None:
1461        self._binary(e, ".")
def eq_sql(self, e: sqlglot.expressions.EQ) -> None:
1463    def eq_sql(self, e: exp.EQ) -> None:
1464        self._binary(e, " = ")
def from_sql(self, e: sqlglot.expressions.From) -> None:
1466    def from_sql(self, e: exp.From) -> None:
1467        self.stack.extend((e.this, "FROM "))
def gt_sql(self, e: sqlglot.expressions.GT) -> None:
1469    def gt_sql(self, e: exp.GT) -> None:
1470        self._binary(e, " > ")
def gte_sql(self, e: sqlglot.expressions.GTE) -> None:
1472    def gte_sql(self, e: exp.GTE) -> None:
1473        self._binary(e, " >= ")
def identifier_sql(self, e: sqlglot.expressions.Identifier) -> None:
1475    def identifier_sql(self, e: exp.Identifier) -> None:
1476        self.stack.append(f'"{e.this}"' if e.quoted else e.this)
def ilike_sql(self, e: sqlglot.expressions.ILike) -> None:
1478    def ilike_sql(self, e: exp.ILike) -> None:
1479        self._binary(e, " ILIKE ")
def in_sql(self, e: sqlglot.expressions.In) -> None:
1481    def in_sql(self, e: exp.In) -> None:
1482        self.stack.append(")")
1483        self._args(e, 1)
1484        self.stack.extend(
1485            (
1486                "(",
1487                " IN ",
1488                e.this,
1489            )
1490        )
def intdiv_sql(self, e: sqlglot.expressions.IntDiv) -> None:
1492    def intdiv_sql(self, e: exp.IntDiv) -> None:
1493        self._binary(e, " DIV ")
def is_sql(self, e: sqlglot.expressions.Is) -> None:
1495    def is_sql(self, e: exp.Is) -> None:
1496        self._binary(e, " IS ")
def like_sql(self, e: sqlglot.expressions.Like) -> None:
1498    def like_sql(self, e: exp.Like) -> None:
1499        self._binary(e, " Like ")
def literal_sql(self, e: sqlglot.expressions.Literal) -> None:
1501    def literal_sql(self, e: exp.Literal) -> None:
1502        self.stack.append(f"'{e.this}'" if e.is_string else e.this)
def lt_sql(self, e: sqlglot.expressions.LT) -> None:
1504    def lt_sql(self, e: exp.LT) -> None:
1505        self._binary(e, " < ")
def lte_sql(self, e: sqlglot.expressions.LTE) -> None:
1507    def lte_sql(self, e: exp.LTE) -> None:
1508        self._binary(e, " <= ")
def mod_sql(self, e: sqlglot.expressions.Mod) -> None:
1510    def mod_sql(self, e: exp.Mod) -> None:
1511        self._binary(e, " % ")
def mul_sql(self, e: sqlglot.expressions.Mul) -> None:
1513    def mul_sql(self, e: exp.Mul) -> None:
1514        self._binary(e, " * ")
def neg_sql(self, e: sqlglot.expressions.Neg) -> None:
1516    def neg_sql(self, e: exp.Neg) -> None:
1517        self._unary(e, "-")
def neq_sql(self, e: sqlglot.expressions.NEQ) -> None:
1519    def neq_sql(self, e: exp.NEQ) -> None:
1520        self._binary(e, " <> ")
def not_sql(self, e: sqlglot.expressions.Not) -> None:
1522    def not_sql(self, e: exp.Not) -> None:
1523        self._unary(e, "NOT ")
def null_sql(self, e: sqlglot.expressions.Null) -> None:
1525    def null_sql(self, e: exp.Null) -> None:
1526        self.stack.append("NULL")
def or_sql(self, e: sqlglot.expressions.Or) -> None:
1528    def or_sql(self, e: exp.Or) -> None:
1529        self._binary(e, " OR ")
def paren_sql(self, e: sqlglot.expressions.Paren) -> None:
1531    def paren_sql(self, e: exp.Paren) -> None:
1532        self.stack.extend(
1533            (
1534                ")",
1535                e.this,
1536                "(",
1537            )
1538        )
def sub_sql(self, e: sqlglot.expressions.Sub) -> None:
1540    def sub_sql(self, e: exp.Sub) -> None:
1541        self._binary(e, " - ")
def subquery_sql(self, e: sqlglot.expressions.Subquery) -> None:
1543    def subquery_sql(self, e: exp.Subquery) -> None:
1544        self._args(e, 2)
1545        alias = e.args.get("alias")
1546        if alias:
1547            self.stack.append(alias)
1548        self.stack.extend((")", e.this, "("))
def table_sql(self, e: sqlglot.expressions.Table) -> None:
1550    def table_sql(self, e: exp.Table) -> None:
1551        self._args(e, 4)
1552        alias = e.args.get("alias")
1553        if alias:
1554            self.stack.append(alias)
1555        for p in reversed(e.parts):
1556            self.stack.extend((p, "."))
1557        self.stack.pop()
def tablealias_sql(self, e: sqlglot.expressions.TableAlias) -> None:
1559    def tablealias_sql(self, e: exp.TableAlias) -> None:
1560        columns = e.columns
1561
1562        if columns:
1563            self.stack.extend((")", columns, "("))
1564
1565        self.stack.extend((e.this, " AS "))
def var_sql(self, e: sqlglot.expressions.Var) -> None:
1567    def var_sql(self, e: exp.Var) -> None:
1568        self.stack.append(e.this)