Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import logging
   5import functools
   6import itertools
   7import typing as t
   8from collections import deque, defaultdict
   9from functools import reduce
  10
  11import sqlglot
  12from sqlglot import Dialect, exp
  13from sqlglot.helper import first, merge_ranges, while_changing
  14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  15
  16if t.TYPE_CHECKING:
  17    from sqlglot.dialects.dialect import DialectType
  18
  19    DateTruncBinaryTransform = t.Callable[
  20        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
  21    ]
  22
  23logger = logging.getLogger("sqlglot")
  24
  25# Final means that an expression should not be simplified
  26FINAL = "final"
  27
  28# Value ranges for byte-sized signed/unsigned integers
  29TINYINT_MIN = -128
  30TINYINT_MAX = 127
  31UTINYINT_MIN = 0
  32UTINYINT_MAX = 255
  33
  34
  35class UnsupportedUnit(Exception):
  36    pass
  37
  38
  39def simplify(
  40    expression: exp.Expression,
  41    constant_propagation: bool = False,
  42    coalesce_simplification: bool = False,
  43    dialect: DialectType = None,
  44):
  45    """
  46    Rewrite sqlglot AST to simplify expressions.
  47
  48    Example:
  49        >>> import sqlglot
  50        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
  51        >>> simplify(expression).sql()
  52        'TRUE'
  53
  54    Args:
  55        expression: expression to simplify
  56        constant_propagation: whether the constant propagation rule should be used
  57        coalesce_simplification: whether the simplify coalesce rule should be used.
  58            This rule tries to remove coalesce functions, which can be useful in certain analyses but
  59            can leave the query more verbose.
  60    Returns:
  61        sqlglot.Expression: simplified expression
  62    """
  63
  64    dialect = Dialect.get_or_raise(dialect)
  65
  66    def _simplify(expression):
  67        pre_transformation_stack = [expression]
  68        post_transformation_stack = []
  69
  70        while pre_transformation_stack:
  71            node = pre_transformation_stack.pop()
  72
  73            if node.meta.get(FINAL):
  74                continue
  75
  76            # group by expressions cannot be simplified, for example
  77            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
  78            # the projection must exactly match the group by key
  79            group = node.args.get("group")
  80
  81            if group and hasattr(node, "selects"):
  82                groups = set(group.expressions)
  83                group.meta[FINAL] = True
  84
  85                for s in node.selects:
  86                    for n in s.walk():
  87                        if n in groups:
  88                            s.meta[FINAL] = True
  89                            break
  90
  91                having = node.args.get("having")
  92                if having:
  93                    for n in having.walk():
  94                        if n in groups:
  95                            having.meta[FINAL] = True
  96                            break
  97
  98            parent = node.parent
  99            root = node is expression
 100
 101            new_node = rewrite_between(node)
 102            new_node = uniq_sort(new_node, root)
 103            new_node = absorb_and_eliminate(new_node, root)
 104            new_node = simplify_concat(new_node)
 105            new_node = simplify_conditionals(new_node)
 106
 107            if constant_propagation:
 108                new_node = propagate_constants(new_node, root)
 109
 110            if new_node is not node:
 111                node.replace(new_node)
 112
 113            pre_transformation_stack.extend(
 114                n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
 115            )
 116            post_transformation_stack.append((new_node, parent))
 117
 118        while post_transformation_stack:
 119            node, parent = post_transformation_stack.pop()
 120            root = node is expression
 121
 122            # Resets parent, arg_key, index pointers– this is needed because some of the
 123            # previous transformations mutate the AST, leading to an inconsistent state
 124            for k, v in tuple(node.args.items()):
 125                node.set(k, v)
 126
 127            # Post-order transformations
 128            new_node = simplify_not(node)
 129            new_node = flatten(new_node)
 130            new_node = simplify_connectors(new_node, root)
 131            new_node = remove_complements(new_node, root)
 132
 133            if coalesce_simplification:
 134                new_node = simplify_coalesce(new_node, dialect)
 135
 136            new_node.parent = parent
 137
 138            new_node = simplify_literals(new_node, root)
 139            new_node = simplify_equality(new_node)
 140            new_node = simplify_parens(new_node)
 141            new_node = simplify_datetrunc(new_node, dialect)
 142            new_node = sort_comparison(new_node)
 143            new_node = simplify_startswith(new_node)
 144
 145            if new_node is not node:
 146                node.replace(new_node)
 147
 148        return new_node
 149
 150    expression = while_changing(expression, _simplify)
 151    remove_where_true(expression)
 152    return expression
 153
 154
 155def catch(*exceptions):
 156    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 157
 158    def decorator(func):
 159        def wrapped(expression, *args, **kwargs):
 160            try:
 161                return func(expression, *args, **kwargs)
 162            except exceptions:
 163                return expression
 164
 165        return wrapped
 166
 167    return decorator
 168
 169
 170def rewrite_between(expression: exp.Expression) -> exp.Expression:
 171    """Rewrite x between y and z to x >= y AND x <= z.
 172
 173    This is done because comparison simplification is only done on lt/lte/gt/gte.
 174    """
 175    if isinstance(expression, exp.Between):
 176        negate = isinstance(expression.parent, exp.Not)
 177
 178        expression = exp.and_(
 179            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 180            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 181            copy=False,
 182        )
 183
 184        if negate:
 185            expression = exp.paren(expression, copy=False)
 186
 187    return expression
 188
 189
 190COMPLEMENT_COMPARISONS = {
 191    exp.LT: exp.GTE,
 192    exp.GT: exp.LTE,
 193    exp.LTE: exp.GT,
 194    exp.GTE: exp.LT,
 195    exp.EQ: exp.NEQ,
 196    exp.NEQ: exp.EQ,
 197}
 198
 199COMPLEMENT_SUBQUERY_PREDICATES = {
 200    exp.All: exp.Any,
 201    exp.Any: exp.All,
 202}
 203
 204
 205def simplify_not(expression):
 206    """
 207    Demorgan's Law
 208    NOT (x OR y) -> NOT x AND NOT y
 209    NOT (x AND y) -> NOT x OR NOT y
 210    """
 211    if isinstance(expression, exp.Not):
 212        this = expression.this
 213        if is_null(this):
 214            return exp.null()
 215        if this.__class__ in COMPLEMENT_COMPARISONS:
 216            right = this.expression
 217            complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
 218            if complement_subquery_predicate:
 219                right = complement_subquery_predicate(this=right.this)
 220
 221            return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
 222        if isinstance(this, exp.Paren):
 223            condition = this.unnest()
 224            if isinstance(condition, exp.And):
 225                return exp.paren(
 226                    exp.or_(
 227                        exp.not_(condition.left, copy=False),
 228                        exp.not_(condition.right, copy=False),
 229                        copy=False,
 230                    )
 231                )
 232            if isinstance(condition, exp.Or):
 233                return exp.paren(
 234                    exp.and_(
 235                        exp.not_(condition.left, copy=False),
 236                        exp.not_(condition.right, copy=False),
 237                        copy=False,
 238                    )
 239                )
 240            if is_null(condition):
 241                return exp.null()
 242        if always_true(this):
 243            return exp.false()
 244        if is_false(this):
 245            return exp.true()
 246        if isinstance(this, exp.Not):
 247            # double negation
 248            # NOT NOT x -> x
 249            return this.this
 250    return expression
 251
 252
 253def flatten(expression):
 254    """
 255    A AND (B AND C) -> A AND B AND C
 256    A OR (B OR C) -> A OR B OR C
 257    """
 258    if isinstance(expression, exp.Connector):
 259        for node in expression.args.values():
 260            child = node.unnest()
 261            if isinstance(child, expression.__class__):
 262                node.replace(child)
 263    return expression
 264
 265
 266def simplify_connectors(expression, root=True):
 267    def _simplify_connectors(expression, left, right):
 268        if isinstance(expression, exp.And):
 269            if is_false(left) or is_false(right):
 270                return exp.false()
 271            if is_zero(left) or is_zero(right):
 272                return exp.false()
 273            if is_null(left) or is_null(right):
 274                return exp.null()
 275            if always_true(left) and always_true(right):
 276                return exp.true()
 277            if always_true(left):
 278                return right
 279            if always_true(right):
 280                return left
 281            return _simplify_comparison(expression, left, right)
 282        elif isinstance(expression, exp.Or):
 283            if always_true(left) or always_true(right):
 284                return exp.true()
 285            if (
 286                (is_null(left) and is_null(right))
 287                or (is_null(left) and always_false(right))
 288                or (always_false(left) and is_null(right))
 289            ):
 290                return exp.null()
 291            if is_false(left):
 292                return right
 293            if is_false(right):
 294                return left
 295            return _simplify_comparison(expression, left, right, or_=True)
 296        elif isinstance(expression, exp.Xor):
 297            if left == right:
 298                return exp.false()
 299
 300    if isinstance(expression, exp.Connector):
 301        return _flat_simplify(expression, _simplify_connectors, root)
 302    return expression
 303
 304
 305LT_LTE = (exp.LT, exp.LTE)
 306GT_GTE = (exp.GT, exp.GTE)
 307
 308COMPARISONS = (
 309    *LT_LTE,
 310    *GT_GTE,
 311    exp.EQ,
 312    exp.NEQ,
 313    exp.Is,
 314)
 315
 316INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 317    exp.LT: exp.GT,
 318    exp.GT: exp.LT,
 319    exp.LTE: exp.GTE,
 320    exp.GTE: exp.LTE,
 321}
 322
 323NONDETERMINISTIC = (exp.Rand, exp.Randn)
 324AND_OR = (exp.And, exp.Or)
 325
 326
 327def _simplify_comparison(expression, left, right, or_=False):
 328    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
 329        ll, lr = left.args.values()
 330        rl, rr = right.args.values()
 331
 332        largs = {ll, lr}
 333        rargs = {rl, rr}
 334
 335        matching = largs & rargs
 336        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}
 337
 338        if matching and columns:
 339            try:
 340                l = first(largs - columns)
 341                r = first(rargs - columns)
 342            except StopIteration:
 343                return expression
 344
 345            if l.is_number and r.is_number:
 346                l = l.to_py()
 347                r = r.to_py()
 348            elif l.is_string and r.is_string:
 349                l = l.name
 350                r = r.name
 351            else:
 352                l = extract_date(l)
 353                if not l:
 354                    return None
 355                r = extract_date(r)
 356                if not r:
 357                    return None
 358                # python won't compare date and datetime, but many engines will upcast
 359                l, r = cast_as_datetime(l), cast_as_datetime(r)
 360
 361            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
 362                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
 363                    return left if (av > bv if or_ else av <= bv) else right
 364                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
 365                    return left if (av < bv if or_ else av >= bv) else right
 366
 367                # we can't ever shortcut to true because the column could be null
 368                if not or_:
 369                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
 370                        if av <= bv:
 371                            return exp.false()
 372                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
 373                        if av >= bv:
 374                            return exp.false()
 375                    elif isinstance(a, exp.EQ):
 376                        if isinstance(b, exp.LT):
 377                            return exp.false() if av >= bv else a
 378                        if isinstance(b, exp.LTE):
 379                            return exp.false() if av > bv else a
 380                        if isinstance(b, exp.GT):
 381                            return exp.false() if av <= bv else a
 382                        if isinstance(b, exp.GTE):
 383                            return exp.false() if av < bv else a
 384                        if isinstance(b, exp.NEQ):
 385                            return exp.false() if av == bv else a
 386    return None
 387
 388
 389def remove_complements(expression, root=True):
 390    """
 391    Removing complements.
 392
 393    A AND NOT A -> FALSE
 394    A OR NOT A -> TRUE
 395    """
 396    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
 397        ops = set(expression.flatten())
 398        for op in ops:
 399            if isinstance(op, exp.Not) and op.this in ops:
 400                return exp.false() if isinstance(expression, exp.And) else exp.true()
 401
 402    return expression
 403
 404
 405def uniq_sort(expression, root=True):
 406    """
 407    Uniq and sort a connector.
 408
 409    C AND A AND B AND B -> A AND B AND C
 410    """
 411    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 412        flattened = tuple(expression.flatten())
 413
 414        if isinstance(expression, exp.Xor):
 415            result_func = exp.xor
 416            # Do not deduplicate XOR as A XOR A != A if A == True
 417            deduped = None
 418            arr = tuple((gen(e), e) for e in flattened)
 419        else:
 420            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 421            deduped = {gen(e): e for e in flattened}
 422            arr = tuple(deduped.items())
 423
 424        # check if the operands are already sorted, if not sort them
 425        # A AND C AND B -> A AND B AND C
 426        for i, (sql, e) in enumerate(arr[1:]):
 427            if sql < arr[i][0]:
 428                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 429                break
 430        else:
 431            # we didn't have to sort but maybe we need to dedup
 432            if deduped and len(deduped) < len(flattened):
 433                expression = result_func(*deduped.values(), copy=False)
 434
 435    return expression
 436
 437
 438def absorb_and_eliminate(expression, root=True):
 439    """
 440    absorption:
 441        A AND (A OR B) -> A
 442        A OR (A AND B) -> A
 443        A AND (NOT A OR B) -> A AND B
 444        A OR (NOT A AND B) -> A OR B
 445    elimination:
 446        (A AND B) OR (A AND NOT B) -> A
 447        (A OR B) AND (A OR NOT B) -> A
 448    """
 449    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
 450        kind = exp.Or if isinstance(expression, exp.And) else exp.And
 451
 452        ops = tuple(expression.flatten())
 453
 454        # Initialize lookup tables:
 455        # Set of all operands, used to find complements for absorption.
 456        op_set = set()
 457        # Sub-operands, used to find subsets for absorption.
 458        subops = defaultdict(list)
 459        # Pairs of complements, used for elimination.
 460        pairs = defaultdict(list)
 461
 462        # Populate the lookup tables
 463        for op in ops:
 464            op_set.add(op)
 465
 466            if not isinstance(op, kind):
 467                # In cases like: A OR (A AND B)
 468                # Subop will be: ^
 469                subops[op].append({op})
 470                continue
 471
 472            # In cases like: (A AND B) OR (A AND B AND C)
 473            # Subops will be: ^     ^
 474            subset = set(op.flatten())
 475            for i in subset:
 476                subops[i].append(subset)
 477
 478            a, b = op.unnest_operands()
 479            if isinstance(a, exp.Not):
 480                pairs[frozenset((a.this, b))].append((op, b))
 481            if isinstance(b, exp.Not):
 482                pairs[frozenset((a, b.this))].append((op, a))
 483
 484        for op in ops:
 485            if not isinstance(op, kind):
 486                continue
 487
 488            a, b = op.unnest_operands()
 489
 490            # Absorb
 491            if isinstance(a, exp.Not) and a.this in op_set:
 492                a.replace(exp.true() if kind == exp.And else exp.false())
 493                continue
 494            if isinstance(b, exp.Not) and b.this in op_set:
 495                b.replace(exp.true() if kind == exp.And else exp.false())
 496                continue
 497            superset = set(op.flatten())
 498            if any(any(subset < superset for subset in subops[i]) for i in superset):
 499                op.replace(exp.false() if kind == exp.And else exp.true())
 500                continue
 501
 502            # Eliminate
 503            for other, complement in pairs[frozenset((a, b))]:
 504                op.replace(complement)
 505                other.replace(complement)
 506
 507    return expression
 508
 509
 510def propagate_constants(expression, root=True):
 511    """
 512    Propagate constants for conjunctions in DNF:
 513
 514    SELECT * FROM t WHERE a = b AND b = 5 becomes
 515    SELECT * FROM t WHERE a = 5 AND b = 5
 516
 517    Reference: https://www.sqlite.org/optoverview.html
 518    """
 519
 520    if (
 521        isinstance(expression, exp.And)
 522        and (root or not expression.same_parent)
 523        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 524    ):
 525        constant_mapping = {}
 526        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 527            if isinstance(expr, exp.EQ):
 528                l, r = expr.left, expr.right
 529
 530                # TODO: create a helper that can be used to detect nested literal expressions such
 531                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 532                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 533                    constant_mapping[l] = (id(l), r)
 534
 535        if constant_mapping:
 536            for column in find_all_in_scope(expression, exp.Column):
 537                parent = column.parent
 538                column_id, constant = constant_mapping.get(column) or (None, None)
 539                if (
 540                    column_id is not None
 541                    and id(column) != column_id
 542                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 543                ):
 544                    column.replace(constant.copy())
 545
 546    return expression
 547
 548
 549INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 550    exp.DateAdd: exp.Sub,
 551    exp.DateSub: exp.Add,
 552    exp.DatetimeAdd: exp.Sub,
 553    exp.DatetimeSub: exp.Add,
 554}
 555
 556INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 557    **INVERSE_DATE_OPS,
 558    exp.Add: exp.Sub,
 559    exp.Sub: exp.Add,
 560}
 561
 562
 563def _is_number(expression: exp.Expression) -> bool:
 564    return expression.is_number
 565
 566
 567def _is_interval(expression: exp.Expression) -> bool:
 568    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 569
 570
 571@catch(ModuleNotFoundError, UnsupportedUnit)
 572def simplify_equality(expression: exp.Expression) -> exp.Expression:
 573    """
 574    Use the subtraction and addition properties of equality to simplify expressions:
 575
 576        x + 1 = 3 becomes x = 2
 577
 578    There are two binary operations in the above expression: + and =
 579    Here's how we reference all the operands in the code below:
 580
 581          l     r
 582        x + 1 = 3
 583        a   b
 584    """
 585    if isinstance(expression, COMPARISONS):
 586        l, r = expression.left, expression.right
 587
 588        if l.__class__ not in INVERSE_OPS:
 589            return expression
 590
 591        if r.is_number:
 592            a_predicate = _is_number
 593            b_predicate = _is_number
 594        elif _is_date_literal(r):
 595            a_predicate = _is_date_literal
 596            b_predicate = _is_interval
 597        else:
 598            return expression
 599
 600        if l.__class__ in INVERSE_DATE_OPS:
 601            l = t.cast(exp.IntervalOp, l)
 602            a = l.this
 603            b = l.interval()
 604        else:
 605            l = t.cast(exp.Binary, l)
 606            a, b = l.left, l.right
 607
 608        if not a_predicate(a) and b_predicate(b):
 609            pass
 610        elif not a_predicate(b) and b_predicate(a):
 611            a, b = b, a
 612        else:
 613            return expression
 614
 615        return expression.__class__(
 616            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 617        )
 618    return expression
 619
 620
 621def simplify_literals(expression, root=True):
 622    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 623        return _flat_simplify(expression, _simplify_binary, root)
 624
 625    if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
 626        return expression.this.this
 627
 628    if type(expression) in INVERSE_DATE_OPS:
 629        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 630
 631    return expression
 632
 633
 634NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 635
 636
 637def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
 638    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
 639        this = _simplify_integer_cast(expr.this)
 640    else:
 641        this = expr.this
 642
 643    if isinstance(expr, exp.Cast) and this.is_int:
 644        num = this.to_py()
 645
 646        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
 647        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
 648        # engine-dependent
 649        if (
 650            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
 651        ) or (
 652            UTINYINT_MIN <= num <= UTINYINT_MAX
 653            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
 654        ):
 655            return this
 656
 657    return expr
 658
 659
 660def _simplify_binary(expression, a, b):
 661    if isinstance(expression, COMPARISONS):
 662        a = _simplify_integer_cast(a)
 663        b = _simplify_integer_cast(b)
 664
 665    if isinstance(expression, exp.Is):
 666        if isinstance(b, exp.Not):
 667            c = b.this
 668            not_ = True
 669        else:
 670            c = b
 671            not_ = False
 672
 673        if is_null(c):
 674            if isinstance(a, exp.Literal):
 675                return exp.true() if not_ else exp.false()
 676            if is_null(a):
 677                return exp.false() if not_ else exp.true()
 678    elif isinstance(expression, NULL_OK):
 679        return None
 680    elif is_null(a) or is_null(b):
 681        return exp.null()
 682
 683    if a.is_number and b.is_number:
 684        num_a = a.to_py()
 685        num_b = b.to_py()
 686
 687        if isinstance(expression, exp.Add):
 688            return exp.Literal.number(num_a + num_b)
 689        if isinstance(expression, exp.Mul):
 690            return exp.Literal.number(num_a * num_b)
 691
 692        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 693        if isinstance(expression, exp.Sub):
 694            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 695        if isinstance(expression, exp.Div):
 696            # engines have differing int div behavior so intdiv is not safe
 697            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 698                return None
 699            return exp.Literal.number(num_a / num_b)
 700
 701        boolean = eval_boolean(expression, num_a, num_b)
 702
 703        if boolean:
 704            return boolean
 705    elif a.is_string and b.is_string:
 706        boolean = eval_boolean(expression, a.this, b.this)
 707
 708        if boolean:
 709            return boolean
 710    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 711        date, b = extract_date(a), extract_interval(b)
 712        if date and b:
 713            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 714                return date_literal(date + b, extract_type(a))
 715            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 716                return date_literal(date - b, extract_type(a))
 717    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 718        a, date = extract_interval(a), extract_date(b)
 719        # you cannot subtract a date from an interval
 720        if a and b and isinstance(expression, exp.Add):
 721            return date_literal(a + date, extract_type(b))
 722    elif _is_date_literal(a) and _is_date_literal(b):
 723        if isinstance(expression, exp.Predicate):
 724            a, b = extract_date(a), extract_date(b)
 725            boolean = eval_boolean(expression, a, b)
 726            if boolean:
 727                return boolean
 728
 729    return None
 730
 731
 732def simplify_parens(expression):
 733    if not isinstance(expression, exp.Paren):
 734        return expression
 735
 736    this = expression.this
 737    parent = expression.parent
 738    parent_is_predicate = isinstance(parent, exp.Predicate)
 739
 740    if (
 741        not isinstance(this, exp.Select)
 742        and not isinstance(parent, (exp.SubqueryPredicate, exp.Bracket))
 743        and (
 744            not isinstance(parent, (exp.Condition, exp.Binary))
 745            or isinstance(parent, exp.Paren)
 746            or (
 747                not isinstance(this, exp.Binary)
 748                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
 749            )
 750            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
 751            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 752            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 753            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 754        )
 755    ):
 756        return this
 757    return expression
 758
 759
 760def _is_nonnull_constant(expression: exp.Expression) -> bool:
 761    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 762
 763
 764def _is_constant(expression: exp.Expression) -> bool:
 765    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 766
 767
 768def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
 769    # COALESCE(x) -> x
 770    if (
 771        isinstance(expression, exp.Coalesce)
 772        and (not expression.expressions or _is_nonnull_constant(expression.this))
 773        # COALESCE is also used as a Spark partitioning hint
 774        and not isinstance(expression.parent, exp.Hint)
 775    ):
 776        return expression.this
 777
 778    # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
 779    # because they are not always equivalent. For example,  if `x` is `NULL` and it comes
 780    # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
 781    if dialect == "redshift":
 782        return expression
 783
 784    if not isinstance(expression, COMPARISONS):
 785        return expression
 786
 787    if isinstance(expression.left, exp.Coalesce):
 788        coalesce = expression.left
 789        other = expression.right
 790    elif isinstance(expression.right, exp.Coalesce):
 791        coalesce = expression.right
 792        other = expression.left
 793    else:
 794        return expression
 795
 796    # This transformation is valid for non-constants,
 797    # but it really only does anything if they are both constants.
 798    if not _is_constant(other):
 799        return expression
 800
 801    # Find the first constant arg
 802    for arg_index, arg in enumerate(coalesce.expressions):
 803        if _is_constant(arg):
 804            break
 805    else:
 806        return expression
 807
 808    coalesce.set("expressions", coalesce.expressions[:arg_index])
 809
 810    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 811    # since we already remove COALESCE at the top of this function.
 812    coalesce = coalesce if coalesce.expressions else coalesce.this
 813
 814    # This expression is more complex than when we started, but it will get simplified further
 815    return exp.paren(
 816        exp.or_(
 817            exp.and_(
 818                coalesce.is_(exp.null()).not_(copy=False),
 819                expression.copy(),
 820                copy=False,
 821            ),
 822            exp.and_(
 823                coalesce.is_(exp.null()),
 824                type(expression)(this=arg.copy(), expression=other.copy()),
 825                copy=False,
 826            ),
 827            copy=False,
 828        )
 829    )
 830
 831
 832CONCATS = (exp.Concat, exp.DPipe)
 833
 834
 835def simplify_concat(expression):
 836    """Reduces all groups that contain string literals by concatenating them."""
 837    if not isinstance(expression, CONCATS) or (
 838        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 839        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 840    ):
 841        return expression
 842
 843    if isinstance(expression, exp.ConcatWs):
 844        sep_expr, *expressions = expression.expressions
 845        sep = sep_expr.name
 846        concat_type = exp.ConcatWs
 847        args = {}
 848    else:
 849        expressions = expression.expressions
 850        sep = ""
 851        concat_type = exp.Concat
 852        args = {
 853            "safe": expression.args.get("safe"),
 854            "coalesce": expression.args.get("coalesce"),
 855        }
 856
 857    new_args = []
 858    for is_string_group, group in itertools.groupby(
 859        expressions or expression.flatten(), lambda e: e.is_string
 860    ):
 861        if is_string_group:
 862            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 863        else:
 864            new_args.extend(group)
 865
 866    if len(new_args) == 1 and new_args[0].is_string:
 867        return new_args[0]
 868
 869    if concat_type is exp.ConcatWs:
 870        new_args = [sep_expr] + new_args
 871    elif isinstance(expression, exp.DPipe):
 872        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
 873
 874    return concat_type(expressions=new_args, **args)
 875
 876
 877def simplify_conditionals(expression):
 878    """Simplifies expressions like IF, CASE if their condition is statically known."""
 879    if isinstance(expression, exp.Case):
 880        this = expression.this
 881        for case in expression.args["ifs"]:
 882            cond = case.this
 883            if this:
 884                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 885                cond = cond.replace(this.pop().eq(cond))
 886
 887            if always_true(cond):
 888                return case.args["true"]
 889
 890            if always_false(cond):
 891                case.pop()
 892                if not expression.args["ifs"]:
 893                    return expression.args.get("default") or exp.null()
 894    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 895        if always_true(expression.this):
 896            return expression.args["true"]
 897        if always_false(expression.this):
 898            return expression.args.get("false") or exp.null()
 899
 900    return expression
 901
 902
 903def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 904    """
 905    Reduces a prefix check to either TRUE or FALSE if both the string and the
 906    prefix are statically known.
 907
 908    Example:
 909        >>> from sqlglot import parse_one
 910        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 911        'TRUE'
 912    """
 913    if (
 914        isinstance(expression, exp.StartsWith)
 915        and expression.this.is_string
 916        and expression.expression.is_string
 917    ):
 918        return exp.convert(expression.name.startswith(expression.expression.name))
 919
 920    return expression
 921
 922
# A (start, end) pair of dates. `_datetrunc_range` documents it as the half-open
# interval [start, end).
DateRange = t.Tuple[datetime.date, datetime.date]
 924
 925
 926def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 927    """
 928    Get the date range for a DATE_TRUNC equality comparison:
 929
 930    Example:
 931        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 932    Returns:
 933        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 934    """
 935    floor = date_floor(date, unit, dialect)
 936
 937    if date != floor:
 938        # This will always be False, except for NULL values.
 939        return None
 940
 941    return floor, floor + interval(unit)
 942
 943
 944def _datetrunc_eq_expression(
 945    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
 946) -> exp.Expression:
 947    """Get the logical expression for a date range"""
 948    return exp.and_(
 949        left >= date_literal(drange[0], target_type),
 950        left < date_literal(drange[1], target_type),
 951        copy=False,
 952    )
 953
 954
 955def _datetrunc_eq(
 956    left: exp.Expression,
 957    date: datetime.date,
 958    unit: str,
 959    dialect: Dialect,
 960    target_type: t.Optional[exp.DataType],
 961) -> t.Optional[exp.Expression]:
 962    drange = _datetrunc_range(date, unit, dialect)
 963    if not drange:
 964        return None
 965
 966    return _datetrunc_eq_expression(left, drange, target_type)
 967
 968
 969def _datetrunc_neq(
 970    left: exp.Expression,
 971    date: datetime.date,
 972    unit: str,
 973    dialect: Dialect,
 974    target_type: t.Optional[exp.DataType],
 975) -> t.Optional[exp.Expression]:
 976    drange = _datetrunc_range(date, unit, dialect)
 977    if not drange:
 978        return None
 979
 980    return exp.and_(
 981        left < date_literal(drange[0], target_type),
 982        left >= date_literal(drange[1], target_type),
 983        copy=False,
 984    )
 985
 986
 987DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 988    exp.LT: lambda l, dt, u, d, t: l
 989    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
 990    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
 991    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
 992    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
 993    exp.EQ: _datetrunc_eq,
 994    exp.NEQ: _datetrunc_neq,
 995}
 996DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 997DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 998
 999
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Whether `left` is a DATE_TRUNC/TIMESTAMP_TRUNC call and `right` is a date literal."""
    if not isinstance(left, DATETRUNCS):
        return False
    return _is_date_literal(right)
1002
1003
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Three shapes are handled:
      * a bare DATE_TRUNC/TIMESTAMP_TRUNC whose argument is a date literal, which is
        folded into the truncated literal;
      * `DATE_TRUNC(unit, x) <op> <date literal>`, rewritten into a predicate on `x`
        via DATETRUNC_BINARY_COMPARISONS;
      * `DATE_TRUNC(unit, x) IN (<date literals>)`, rewritten into an OR of ranges on `x`.

    The `catch` decorator returns `expression` unchanged if `dateutil` is missing
    (ModuleNotFoundError from `interval`) or the unit is unsupported (UnsupportedUnit).
    """
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        this = expression.this
        trunc_type = extract_type(this)
        date = extract_date(this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
        # Otherwise fall through: a trunc node is neither Binary nor In, so it is returned as-is.
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        # Only `DATE_TRUNC(...) <op> <date literal>` with the trunc on the left is rewritten.
        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        trunc_arg = l.this
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        # A transform may return None (e.g. EQ/NEQ against a date not on a unit boundary);
        # in that case the original comparison is kept.
        return (
            DATETRUNC_BINARY_COMPARISONS[comparison](
                trunc_arg, date, unit, dialect, extract_type(r)
            )
            or expression
        )

    if isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit, dialect)
                # Dates not aligned to a unit boundary can never match, so they are skipped.
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            # presumably merges overlapping/adjacent ranges — merge_ranges is defined in
            # sqlglot.helper, not shown here; verify.
            ranges = merge_ranges(ranges)
            target_type = extract_type(*rs)

            # NOTE(review): this OR is not wrapped in a Paren; confirm precedence is preserved
            # when the IN sat inside an AND chain.
            return exp.or_(
                *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
            )

    return expression
1067
1068
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put a comparison into canonical operand order.

    Columns are preferred on the left and constants on the right. Otherwise operands are
    ordered by their `gen` string. When operands are swapped, the comparison operator is
    mirrored through INVERSE_COMPARISONS.
    """
    kind = expression.__class__
    if kind not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_ordered = (
        (left_is_column and not right_is_column)
        or (right_is_const and not left_is_const)
        or isinstance(right, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if not needs_swap:
        return expression

    mirrored = INVERSE_COMPARISONS.get(kind, kind)
    return mirrored(this=right, expression=left)
1088
1089
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
# Each entry is a (join.side, join.kind) pair that `remove_where_true` may turn into a
# CROSS JOIN when the ON condition is always true.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
1099
1100
def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn eligible `JOIN ... ON <true>` into CROSS JOINs.

    Mutates `expression` in place and returns None.
    """
    for where in expression.find_all(exp.Where):
        if not always_true(where.this):
            continue
        where.pop()

    for join in expression.find_all(exp.Join):
        on = join.args.get("on")
        convertible = (
            always_true(on)
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        )
        if not convertible:
            continue
        on.pop()
        join.set("side", None)
        join.set("kind", "CROSS")
1115
1116
def always_true(expression):
    """Truthy when `expression` is TRUE or a non-zero literal."""
    if isinstance(expression, exp.Boolean) and expression.this:
        return expression.this
    return isinstance(expression, exp.Literal) and not is_zero(expression)
1121
1122
def always_false(expression):
    """True when `expression` is FALSE, NULL, or a literal equal to zero."""
    return any(check(expression) for check in (is_false, is_null, is_zero))
1125
1126
def is_zero(expression):
    """True when `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
1129
1130
def is_complement(a, b):
    """True when `b` is exactly `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1133
1134
def is_false(a: exp.Expression) -> bool:
    """True when `a` is exactly a FALSE boolean node (subclasses excluded)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1137
1138
def is_null(a: exp.Expression) -> bool:
    """True when `a` is exactly a NULL node (subclasses excluded)."""
    node_type = type(a)
    return node_type is exp.Null
1141
1142
def eval_boolean(expression, a, b):
    """Evaluate comparison `expression` on the Python values `a` and `b`.

    Returns a TRUE/FALSE literal, or None when `expression` is not a supported comparison.
    """
    evaluators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for kinds, compare in evaluators:
        if isinstance(expression, kinds):
            return boolean_literal(compare(a, b))
    return None
1157
1158
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Coerce `value` to a `datetime.date`.

    datetimes lose their time part, dates pass through, and anything else is parsed
    with `datetime.fromisoformat`. None is returned if parsing raises ValueError.
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
1168
1169
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Coerce `value` to a `datetime.datetime`.

    datetimes pass through, dates become midnight datetimes, and anything else is parsed
    with `datetime.fromisoformat`. None is returned if parsing raises ValueError.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed
1179
1180
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Coerce `value` according to the target temporal type `to`.

    Args:
        value: a string, date or datetime. Falsy values yield None.
        to: the target type. DATE yields a `datetime.date`; any other temporal type yields
            a `datetime.datetime`.

    Returns:
        The coerced value, or None if `value` is falsy, unparseable, or `to` is not temporal.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1189
1190
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a date-producing cast of a literal.

    Supports `CAST(<literal> AS <temporal type>)` and `TS_OR_DS_TO_DATE(<literal>)` without
    a format, including nested combinations of the two.

    Returns:
        A `datetime.date` for DATE targets, a `datetime.datetime` for other temporal targets,
        or None if the expression cannot be evaluated statically.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
1206
1207
def _is_date_literal(expression: exp.Expression) -> bool:
    """Whether `extract_date` can statically evaluate `expression`."""
    extracted = extract_date(expression)
    return extracted is not None
1210
1211
def extract_interval(expression):
    """Convert an INTERVAL expression with a whole-number count into a relativedelta.

    Returns:
        The relativedelta, or None when the count is missing or not a whole number,
        the unit is unsupported, or `dateutil` is not installed.
    """
    try:
        value = expression.this.to_py()
        n = int(value)
        # int() silently truncates non-integral numbers (e.g. INTERVAL 1.5 DAY -> 1 day);
        # refuse to simplify those instead. String counts such as '5' are parsed by int().
        if not isinstance(value, str) and n != value:
            return None
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError, TypeError):
        # TypeError covers a NULL count, where to_py() yields None.
        return None
1219
1220
def extract_type(*expressions):
    """Return the first truthy target type among `expressions`.

    For a Cast the type is its `to`; for anything else it is `.type`. If none is truthy,
    the last candidate inspected is returned (None when no expressions are given).
    """
    candidate = None
    for node in expressions:
        candidate = node.to if isinstance(node, exp.Cast) else node.type
        if candidate:
            return candidate
    return candidate
1229
1230
def date_literal(date, target_type=None):
    """Build `CAST('<date>' AS <type>)` for a Python date/datetime.

    `target_type` is used when it is temporal. Otherwise DATETIME is used for datetimes
    and DATE for dates.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        resolved_type = target_type
    elif isinstance(date, datetime.datetime):
        resolved_type = exp.DataType.Type.DATETIME
    else:
        resolved_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), resolved_type)
1240
1241
def interval(unit: str, n: int = 1):
    """Return a `dateutil` relativedelta spanning `n` units of `unit`.

    Raises:
        ModuleNotFoundError: if `dateutil` is not installed.
        UnsupportedUnit: if `unit` is not one of the supported lowercase unit names.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier)
    unit_spec = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }.get(unit)

    if unit_spec is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = unit_spec
    return relativedelta(**{keyword: multiplier * n})
1263
1264
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Truncate `d` down to the start of its enclosing `unit`.

    Args:
        d: a date or datetime. For datetimes the time of day is reset to midnight, since
            every supported unit is at least one day long.
        unit: "year", "quarter", "month", "week" or "day".
        dialect: provides WEEK_OFFSET, the first weekday relative to Monday (0); per sqlglot's
            Dialect docs -1 means weeks start on Sunday.

    Returns:
        A value of the same type as `d`.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if isinstance(d, datetime.datetime):
        # Without this, e.g. DATE_TRUNC('day', '2021-01-01 12:00') would keep 12:00.
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Days elapsed since the week start. The modulo keeps it in [0, 6]; without it a
        # negative WEEK_OFFSET would send the week-start day itself back a full week.
        days_since_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1286
1287
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary, leaving boundary values unchanged."""
    floored = date_floor(d, unit, dialect)
    return d if floored == d else floored + interval(unit)
1295
1296
def boolean_literal(condition):
    """Map a Python truth value to a TRUE/FALSE literal expression."""
    if condition:
        return exp.true()
    return exp.false()
1299
1300
def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-simplify the operands of a flattened binary chain (e.g. a AND b AND c).

    `simplifier(expression, a, b)` is tried on every pair of operands. When it returns a
    truthy result that is not `expression` itself, the pair is replaced by that result, and
    the result is retried against the remaining operands.

    Args:
        expression: the binary node whose chain is flattened.
        simplifier: callable (expression, a, b) -> replacement or falsy/`expression`.
        root: when False, the chain is only processed if `expression.same_parent` is falsy
            (presumably meaning the parent is not the same connector, so each chain is
            handled once from its top — verify).

    Returns:
        A freshly rebuilt chain if any pair was combined, otherwise `expression`.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # Mutating the deque mid-iteration is safe only because we break right away.
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # `a` combined with nothing: it is final.
                operands.append(a)

        # Fewer operands than we started with means at least one pair was merged.
        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
1325
1326
def gen(expression: t.Any, comments: bool = False) -> str:
    """Render `expression` as a cheap pseudo-SQL string for sorting and deduplication.

    Optimization frequently needs to sort and dedupe SQL fragments. Running the real
    generator for that is expensive, so a bare-minimum generator (`Gen`) is used instead.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)
1338
1339
class Gen:
    """Minimal stack-based pseudo-SQL generator backing `gen`.

    Work items on `self.stack` are expressions, lists, or plain values. Items are popped
    LIFO, so each handler pushes the pieces it wants emitted LAST first. Non-expression,
    non-list items are appended to `self.sqls` via `str()`.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """Render `expression` into a pseudo-SQL string.

        Args:
            expression: root node to render.
            comments: whether to emit each node's comments as ` /*...*/`.

        Returns:
            The concatenation of every emitted fragment.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the trailing separator only if one was pushed. For an all-None list
                # the old unconditional pop removed an unrelated item (e.g. a closing ")")
                # or raised IndexError on an empty stack.
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        # Drop the trailing "." only if any part was pushed (same guard as in `gen`).
        if parts:
            self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): casing differs from ILIKE; kept as-is because the output feeds sort keys.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        # Drop the trailing "." only if any part was pushed (same guard as in `gen`).
        if parts:
            self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Pushed right-to-left so `this` is emitted first.
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        # Generic fallback: NAME(<all arg values, including None placeholders>).
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push `:key,value` pairs for the non-None args of `node`, starting at `arg_index`.

        Returns:
            Whether any pair was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
logger = <Logger sqlglot (WARNING)>
FINAL = 'final'
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date/time unit is not supported by the simplifier's date helpers."""

Raised by the date helpers (e.g. `interval`, `date_floor`) when given a unit they do not support.

def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, coalesce_simplification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None):
 40def simplify(
 41    expression: exp.Expression,
 42    constant_propagation: bool = False,
 43    coalesce_simplification: bool = False,
 44    dialect: DialectType = None,
 45):
 46    """
 47    Rewrite sqlglot AST to simplify expressions.
 48
 49    Example:
 50        >>> import sqlglot
 51        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 52        >>> simplify(expression).sql()
 53        'TRUE'
 54
 55    Args:
 56        expression: expression to simplify
 57        constant_propagation: whether the constant propagation rule should be used
 58        coalesce_simplification: whether the simplify coalesce rule should be used.
 59            This rule tries to remove coalesce functions, which can be useful in certain analyses but
 60            can leave the query more verbose.
 61    Returns:
 62        sqlglot.Expression: simplified expression
 63    """
 64
 65    dialect = Dialect.get_or_raise(dialect)
 66
 67    def _simplify(expression):
 68        pre_transformation_stack = [expression]
 69        post_transformation_stack = []
 70
 71        while pre_transformation_stack:
 72            node = pre_transformation_stack.pop()
 73
 74            if node.meta.get(FINAL):
 75                continue
 76
 77            # group by expressions cannot be simplified, for example
 78            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 79            # the projection must exactly match the group by key
 80            group = node.args.get("group")
 81
 82            if group and hasattr(node, "selects"):
 83                groups = set(group.expressions)
 84                group.meta[FINAL] = True
 85
 86                for s in node.selects:
 87                    for n in s.walk():
 88                        if n in groups:
 89                            s.meta[FINAL] = True
 90                            break
 91
 92                having = node.args.get("having")
 93                if having:
 94                    for n in having.walk():
 95                        if n in groups:
 96                            having.meta[FINAL] = True
 97                            break
 98
 99            parent = node.parent
100            root = node is expression
101
102            new_node = rewrite_between(node)
103            new_node = uniq_sort(new_node, root)
104            new_node = absorb_and_eliminate(new_node, root)
105            new_node = simplify_concat(new_node)
106            new_node = simplify_conditionals(new_node)
107
108            if constant_propagation:
109                new_node = propagate_constants(new_node, root)
110
111            if new_node is not node:
112                node.replace(new_node)
113
114            pre_transformation_stack.extend(
115                n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
116            )
117            post_transformation_stack.append((new_node, parent))
118
119        while post_transformation_stack:
120            node, parent = post_transformation_stack.pop()
121            root = node is expression
122
123            # Resets parent, arg_key, index pointers– this is needed because some of the
124            # previous transformations mutate the AST, leading to an inconsistent state
125            for k, v in tuple(node.args.items()):
126                node.set(k, v)
127
128            # Post-order transformations
129            new_node = simplify_not(node)
130            new_node = flatten(new_node)
131            new_node = simplify_connectors(new_node, root)
132            new_node = remove_complements(new_node, root)
133
134            if coalesce_simplification:
135                new_node = simplify_coalesce(new_node, dialect)
136
137            new_node.parent = parent
138
139            new_node = simplify_literals(new_node, root)
140            new_node = simplify_equality(new_node)
141            new_node = simplify_parens(new_node)
142            new_node = simplify_datetrunc(new_node, dialect)
143            new_node = sort_comparison(new_node)
144            new_node = simplify_startswith(new_node)
145
146            if new_node is not node:
147                node.replace(new_node)
148
149        return new_node
150
151    expression = while_changing(expression, _simplify)
152    remove_where_true(expression)
153    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression: expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
  • coalesce_simplification: whether the simplify coalesce rule should be used. This rule tries to remove coalesce functions, which can be useful in certain analyses but can leave the query more verbose.
Returns:
  • sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The wrapped function returns its first argument (the input expression) unchanged
    whenever one of `exceptions` escapes from it. Other exceptions propagate. The wrapper
    keeps the wrapped function's name and docstring.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that makes a simplification function return its input expression unchanged if any of `exceptions` is raised.

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite `x BETWEEN y AND z` as `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, hence the rewrite. When the
    BETWEEN sits directly under a NOT, the conjunction is parenthesized so the negation
    still covers both bounds.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)
    subject = expression.this

    lower_check = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_check = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_check, upper_check, copy=False)

    return exp.paren(conjunction, copy=False) if under_not else conjunction

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

COMPLEMENT_SUBQUERY_PREDICATES = {<class 'sqlglot.expressions.All'>: <class 'sqlglot.expressions.Any'>, <class 'sqlglot.expressions.Any'>: <class 'sqlglot.expressions.All'>}
def simplify_not(expression):
    """
    Simplify a NOT node.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    Also: NOT NULL -> NULL, NOT <comparison> -> complementary comparison,
    NOT <always true> -> FALSE, NOT FALSE -> TRUE, NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    inner = expression.this

    if is_null(inner):
        return exp.null()

    if inner.__class__ in COMPLEMENT_COMPARISONS:
        # NOT (a < b) -> a >= b; an ALL/ANY subquery operand flips to ANY/ALL as well
        rhs = inner.expression
        flipped_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(rhs.__class__)
        if flipped_predicate:
            rhs = flipped_predicate(this=rhs.this)
        complement = COMPLEMENT_COMPARISONS[inner.__class__]
        return complement(this=inner.this, expression=rhs)

    if isinstance(inner, exp.Paren):
        condition = inner.unnest()
        if isinstance(condition, (exp.And, exp.Or)):
            # swap the connector and push the negation onto both operands
            combine = exp.or_ if isinstance(condition, exp.And) else exp.and_
            return exp.paren(
                combine(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            )
        if is_null(condition):
            return exp.null()

    if always_true(inner):
        return exp.false()
    if is_false(inner):
        return exp.true()
    if isinstance(inner, exp.Not):
        # double negation: NOT NOT x -> x
        return inner.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
    """
    Inline directly nested connectors of the same type:
        A AND (B AND C) -> A AND B AND C
        A OR (B OR C) -> A OR B OR C

    Non-connector expressions are returned untouched.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            # replace the (possibly parenthesized) operand by the bare inner connector
            operand.replace(unwrapped)

    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
    """Fold AND / OR / XOR connectors whose operands are constant.

    AND: FALSE or zero on either side -> FALSE; NULL on either side -> NULL;
         always-true operands are dropped; otherwise defer to `_simplify_comparison`.
    OR:  always-true on either side -> TRUE; NULL combined with NULL or an
         always-false operand -> NULL; FALSE operands are dropped; otherwise defer to
         `_simplify_comparison` in OR mode.
    XOR: identical operands -> FALSE.

    The per-pair folding is applied through `_flat_simplify`; a pair that cannot be
    folded yields None.
    """

    def _fold(connector, left, right):
        if isinstance(connector, exp.And):
            if is_false(left) or is_false(right) or is_zero(left) or is_zero(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            left_true = always_true(left)
            right_true = always_true(right)
            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(connector, left, right)

        if isinstance(connector, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            left_null = is_null(left)
            right_null = is_null(right)
            if (left_null and (right_null or always_false(right))) or (
                always_false(left) and right_null
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(connector, left, right, or_=True)

        if isinstance(connector, exp.Xor) and left == right:
            return exp.false()

        return None

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _fold, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
AND_OR = (<class 'sqlglot.expressions.And'>, <class 'sqlglot.expressions.Or'>)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

        A AND NOT A -> FALSE
        A OR NOT A -> TRUE

    Only AND/OR nodes are considered, and when `root` is false only those whose parent
    is not a connector of the same type (`same_parent`).
    """
    if not isinstance(expression, AND_OR):
        return expression
    if not root and expression.same_parent:
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )
    if has_complement:
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE.

def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector by their `gen` string.

        C AND A AND B AND B -> A AND B AND C

    XOR operands are sorted but never deduplicated, since A XOR A is not A.
    A new connector is built only when reordering or deduplication is needed.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    operands = tuple(expression.flatten())

    if isinstance(expression, exp.Xor):
        build = exp.xor
        unique = None
        keyed = tuple((gen(operand), operand) for operand in operands)
    else:
        build = exp.and_ if isinstance(expression, exp.And) else exp.or_
        unique = {gen(operand): operand for operand in operands}
        keyed = tuple(unique.items())

    in_order = all(prev[0] <= cur[0] for prev, cur in zip(keyed, keyed[1:]))
    if not in_order:
        return build(*(operand for _, operand in sorted(keyed)), copy=False)

    # already ordered; rebuild only if deduplication removed something
    if unique and len(unique) < len(operands):
        return build(*unique.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
    """
    Apply absorption and elimination to an AND/OR connector, in place.

    Absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    Elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten by replacing them with TRUE/FALSE or with the surviving
    operand; later passes remove the resulting constants.
    """
    if not isinstance(expression, AND_OR) or (not root and expression.same_parent):
        return expression

    # connector type expected one level down, e.g. the ORs inside an AND
    inner_kind = exp.Or if isinstance(expression, exp.And) else exp.And
    if inner_kind == exp.And:
        inner_identity, outer_identity = exp.true, exp.false
    else:
        inner_identity, outer_identity = exp.false, exp.true

    operands = tuple(expression.flatten())

    # every top-level operand, used to detect NOT x when x is also present
    all_operands = set()
    # operand -> list of operand-sets containing it, used to detect strict subsets
    containing_sets = defaultdict(list)
    # {x, y} -> [(inner connector, survivor)], used for elimination
    complement_pairs = defaultdict(list)

    for operand in operands:
        all_operands.add(operand)

        if not isinstance(operand, inner_kind):
            # e.g. the lone A in: A OR (A AND B)
            containing_sets[operand].append({operand})
            continue

        # e.g. both groups in: (A AND B) OR (A AND B AND C)
        members = set(operand.flatten())
        for member in members:
            containing_sets[member].append(members)

        left, right = operand.unnest_operands()
        if isinstance(left, exp.Not):
            complement_pairs[frozenset((left.this, right))].append((operand, right))
        if isinstance(right, exp.Not):
            complement_pairs[frozenset((left, right.this))].append((operand, left))

    for operand in operands:
        if not isinstance(operand, inner_kind):
            continue

        left, right = operand.unnest_operands()

        # absorb a negated member whose positive form is a top-level operand
        if isinstance(left, exp.Not) and left.this in all_operands:
            left.replace(inner_identity())
            continue
        if isinstance(right, exp.Not) and right.this in all_operands:
            right.replace(inner_identity())
            continue

        # absorb a group that strictly contains another operand's member set
        members = set(operand.flatten())
        if any(smaller < members for member in members for smaller in containing_sets[member]):
            operand.replace(outer_identity())
            continue

        # eliminate complementary pairs down to their shared operand
        for partner, survivor in complement_pairs[frozenset((left, right))]:
            operand.replace(survivor)
            partner.replace(survivor)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
    """
    Propagate `column = literal` equalities through a conjunction in DNF:

        SELECT * FROM t WHERE a = b AND b = 5
    becomes
        SELECT * FROM t WHERE a = 5 AND b = 5

    IF subtrees are not searched for equalities. The defining column occurrence itself
    and columns tested with `IS NULL` are left untouched.

    Reference: https://www.sqlite.org/optoverview.html
    """
    if not isinstance(expression, exp.And):
        return expression
    if not root and expression.same_parent:
        return expression
    if not sqlglot.optimizer.normalize.normalized(expression, dnf=True):
        return expression

    # column -> (id of the column node in its defining equality, literal value)
    bindings = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue
        lhs, rhs = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            bindings[lhs] = (id(lhs), rhs)

    if not bindings:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        binding = bindings.get(column)
        if binding is None:
            continue
        defining_id, literal = binding
        if id(column) == defining_id:
            continue
        parent = column.parent
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue
        column.replace(literal.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

# NOTE(review): the listing below is the `wrapped` closure from `catch` (identical code,
# same source lines 160-164), not the body of simplify_equality. The signature shown,
# `(expression, *args, **kwargs)`, is the wrapper's as well. Presumably `catch` does not
# apply functools.wraps, so the decorated function's own signature and source are hidden.
def simplify_equality(expression, *args, **kwargs):
160        def wrapped(expression, *args, **kwargs):
161            try:
162                return func(expression, *args, **kwargs)
163            except exceptions:
164                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
    """Simplify literal-level expressions.

    - Binary nodes that are not connectors are folded pairwise with `_simplify_binary`
      via `_flat_simplify`.
    - A double negation `--x` collapses to `x`.
    - Nodes whose exact type is in INVERSE_DATE_OPS are passed to `_simplify_binary`
      with their interval; the original is kept when that yields nothing.
    """
    is_arithmetic = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_arithmetic:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if isinstance(operand, exp.Neg):
            return operand.this

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """Remove a Paren node when its content can stand without it in its parent.

    Parentheses around a SELECT, or directly under a subquery predicate or a
    bracket, are always kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select) or isinstance(outer, (exp.SubqueryPredicate, exp.Bracket)):
        return expression

    outer_is_predicate = isinstance(outer, exp.Predicate)
    removable = (
        # the parent is not an operator, or is another Paren
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        # a non-binary inner node, except NOT / IS directly under a predicate
        or (
            not isinstance(inner, exp.Binary)
            and not (isinstance(inner, (exp.Not, exp.Is)) and outer_is_predicate)
        )
        # a predicate nested in a non-predicate
        or (isinstance(inner, exp.Predicate) and not outer_is_predicate)
        # arithmetic whose grouping cannot change the result under this parent
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )

    return inner if removable else expression
def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Simplify COALESCE calls and comparisons involving them.

    - COALESCE(x) and COALESCE(<non-null constant>, ...) collapse to their first
      argument, unless the COALESCE is a hint (a Spark partitioning hint).
    - For a comparison between a COALESCE and a constant, e.g. COALESCE(x, 1) = 2, the
      COALESCE is truncated before its first constant argument `c`, and the comparison
      becomes
          (NOT x IS NULL AND <comparison>) OR (x IS NULL AND c <op> <other>)
      which later passes simplify further. Skipped for Redshift (see comment below).
    """
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
    # because they are not always equivalent. For example, if `x` is `NULL` and it comes
    # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
    if dialect == "redshift" or not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # Valid for non-constants too, but it only pays off when both sides are constant.
    if not _is_constant(other):
        return expression

    fallbacks = coalesce.expressions
    arg_index = next((i for i, candidate in enumerate(fallbacks) if _is_constant(candidate)), None)
    if arg_index is None:
        return expression
    arg = fallbacks[arg_index]

    # keep only the arguments that precede the first constant
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # A COALESCE without fallbacks is just its first argument; unwrapping it here saves
    # a simplify iteration.
    if not coalesce.expressions:
        coalesce = coalesce.this

    not_null_branch = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    null_branch = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(not_null_branch, null_branch, copy=False))
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Handles CONCAT, CONCAT_WS (only when its separator is a string literal) and `||`
    chains. Adjacent string literals are merged (joined with the CONCAT_WS separator,
    or with "" otherwise). If everything merges into one literal, that literal is
    returned. A `||` chain is rebuilt as a left-nested DPipe tree.
    """
    if not isinstance(expression, CONCATS):
        return expression
    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string:
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *parts = expression.expressions
        separator = sep_expr.name
        builder = exp.ConcatWs
        extra_args = {}
    else:
        parts = expression.expressions
        separator = ""
        builder = exp.Concat
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    for is_literal, run in itertools.groupby(parts or expression.flatten(), lambda e: e.is_string):
        if is_literal:
            merged.append(exp.Literal.string(separator.join(literal.name for literal in run)))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if builder is exp.ConcatWs:
        merged = [sep_expr, *merged]
    elif isinstance(expression, exp.DPipe):
        return reduce(lambda acc, nxt: exp.DPipe(this=acc, expression=nxt), merged)

    return builder(expressions=merged, **extra_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    CASE:
      - For the ``CASE x WHEN v ...`` form, each ``WHEN v`` is first rewritten to
        ``WHEN x = v``.
      - The first branch whose condition is always true becomes the result.
      - Branches whose condition is always false are dropped; if none remain, the ELSE
        value (or NULL) is returned.
    IF (not directly inside a CASE): an always-true condition yields the true branch,
    and an always-false one yields the false branch (or NULL when there is none).
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: `case.pop()` removes the branch from
        # expression.args["ifs"], and mutating the list while iterating it would skip
        # the branch that follows each removed one.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Replace a STARTSWITH whose string and prefix are both string literals with the
    constant result (TRUE or FALSE); anything else is returned unchanged.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    subject, prefix = expression.this, expression.expression
    if not (subject.is_string and prefix.is_string):
        return expression

    return exp.convert(expression.name.startswith(prefix.name))

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.Dialect, sqlglot.expressions.DataType], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.In'>}
# NOTE(review): the listing below is the `wrapped` closure from `catch` (identical code,
# same source lines 160-164), not the body of simplify_datetrunc. The signature shown,
# `(expression, *args, **kwargs)`, is the wrapper's as well. Presumably `catch` does not
# apply functools.wraps, so the decorated function's own signature and source are hidden.
def simplify_datetrunc(expression, *args, **kwargs):
160        def wrapped(expression, *args, **kwargs):
161            try:
162                return func(expression, *args, **kwargs)
163            except exceptions:
164                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    The comparison is kept as is when the left side is a column and the right is not,
    when the right side is a constant and the left is not, or when the right side is
    a subquery predicate. Otherwise the operands are swapped, using the inverse
    operator, if the right side is a column and the left is not, if the left side is a
    constant and the right is not, or if the left's `gen` string sorts after the right's.
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_canonical = (
        (lhs_is_column and not rhs_is_column)
        or (rhs_is_const and not lhs_is_const)
        or isinstance(rhs, exp.SubqueryPredicate)
    )
    if already_canonical:
        return expression

    should_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or gen(lhs) > gen(rhs)
    )
    if not should_swap:
        return expression

    swapped_type = INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)
    return swapped_type(this=rhs, expression=lhs)
JOINS = {('RIGHT', ''), ('RIGHT', 'OUTER'), ('', 'INNER'), ('', '')}
def remove_where_true(expression):
    """Strip always-true filters from `expression`, in place; returns None.

    - WHERE clauses whose condition is always true are removed.
    - A join with an always-true ON condition, no USING, no join method, and a
      (side, kind) pair listed in JOINS loses its ON clause and becomes a CROSS join.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.pop()

    for join in expression.find_all(exp.Join):
        on_condition = join.args.get("on")
        if not always_true(on_condition):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        on_condition.pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
    """Truthy for a TRUE boolean literal or any non-zero literal; falsy otherwise."""
    is_true_boolean = isinstance(expression, exp.Boolean) and expression.this
    return is_true_boolean or (isinstance(expression, exp.Literal) and not is_zero(expression))
def always_false(expression):
    """True for a FALSE boolean literal, a NULL, or a literal equal to zero."""
    if is_false(expression) or is_null(expression):
        return True
    return is_zero(expression)
def is_zero(expression):
    """True when `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
def is_complement(a, b):
    """True when `b` is exactly `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """True only for a plain Boolean node (not a subclass) holding a false value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """True only when `a` is exactly an exp.Null node (subclasses excluded)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
    """Evaluate the comparison kind of `expression` on Python values `a` and `b`.

    EQ and IS use ==; NEQ, GT, GTE, LT and LTE use the matching Python operator. The
    result is returned as a TRUE/FALSE literal, or None for any other node type.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Coerce `value` to a `datetime.date`.

    Datetimes are truncated to their date part and dates are returned unchanged.
    Anything else is parsed as an ISO-8601 string.

    Returns:
        The date, or None when `value` cannot be parsed. This includes values that are
        not strings at all, for which `fromisoformat` raises TypeError.
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Coerce `value` to a `datetime.datetime`.

    Datetimes are returned unchanged, and dates become midnight of that day. Anything
    else is parsed as an ISO-8601 string.

    Returns:
        The datetime, or None when `value` cannot be parsed. This includes values that
        are not strings at all, for which `fromisoformat` raises TypeError.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically cast `value` to the temporal type `to`.

    DATE targets produce a `datetime.date` (via `cast_as_date`). Other temporal targets
    produce a `datetime.datetime` (via `cast_as_datetime`); hence the return annotation
    is Union[date, datetime], not the previous duplicated Union[date, date]. Falsy
    values, non-temporal targets and unparsable values yield None.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a (possibly nested) temporal cast of a string literal.

    Accepts `CAST(<literal> AS <type>)` and a format-less `TsOrDsToDate(<literal>)`,
    which is treated as a cast to DATE. Nested casts of either kind are evaluated
    recursively. The result comes from `cast_value`: a date for DATE targets, or a
    datetime for other temporal targets, hence the Union[date, datetime] annotation.
    Any other shape yields None.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
    """Convert an INTERVAL node into the value returned by `interval(unit, n)`.

    Returns None when the amount is not an integer (ValueError), the unit is not
    supported (UnsupportedUnit), or dateutil is not installed (ModuleNotFoundError).
    """
    try:
        count = int(expression.this.to_py())
        unit_name = expression.text("unit").lower()
        return interval(unit_name, count)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def extract_type(*expressions):
    """Return the first truthy data type among `expressions`.

    A Cast contributes its target type (`to`); any other node contributes its `type`.
    If none is truthy, the last inspected value (or None for no arguments) is returned.
    """
    target_type = None
    for candidate in expressions:
        if isinstance(candidate, exp.Cast):
            target_type = candidate.to
        else:
            target_type = candidate.type
        if target_type:
            return target_type
    return target_type
def date_literal(date, target_type=None):
    """Build `CAST('<date>' AS <type>)`.

    When `target_type` is missing or not a temporal type, DATETIME is used for
    `datetime.datetime` values and DATE for everything else.
    """
    has_temporal_target = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)
    if not has_temporal_target:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
    """Return a dateutil `relativedelta` spanning `n` units of `unit`.

    Supported units: year, quarter (3 months), month, week, day, hour, minute and
    second. dateutil is imported lazily, on every call, before the unit is checked.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    from dateutil.relativedelta import relativedelta

    unit_table = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )
    for unit_name, field, multiplier in unit_table:
        if unit == unit_name:
            return relativedelta(**{field: multiplier * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the first day of its `unit` period.

    Supported units: year, quarter, month, week and day. For weeks, the start day is
    Monday shifted by `dialect.WEEK_OFFSET`: Python's `weekday()` is 0 for Monday, so an
    offset of 0 means Monday and -1 means Sunday.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # first month of the quarter: 1, 4, 7 or 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Days elapsed since the most recent week start, kept in [0, 6]. Without the
        # modulo, a Sunday with WEEK_OFFSET = -1 would be moved back a full week instead
        # of staying put, and a positive offset could move the date forward.
        days_since_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to a `unit` boundary.

    A date already on a boundary is returned unchanged; otherwise the result is the
    floor of `d` plus one `unit`.
    """
    floored = date_floor(d, unit, dialect)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
    """Return a TRUE literal if `condition` is truthy, else a FALSE literal."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: t.Any, comments: bool = False) -> str:
    """Render `expression` with a minimal pseudo-SQL generator.

    The output is only meant to be cheap, sortable and suitable for deduplication;
    invoking the full SQL generator for that purpose would be expensive.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

Arguments:
  • expression: the expression to convert into a SQL string.
  • comments: whether to include the expression's comments.
class Gen:
    """Bare-minimum, stack-based pseudo-SQL renderer.

    Work items are pushed onto ``self.stack`` in *reverse* output order, so
    the piece that must be emitted first ends up on top. ``gen`` pops items
    until the stack is empty:

    - Expressions are dispatched to a ``<key>_sql`` handler, which pushes
      their parts.
    - Lists are expanded into comma-separated items.
    - Any other non-None value is appended to ``self.sqls`` via ``str``.
    """

    def __init__(self):
        # Pending work items: expressions, lists of items, or plain values.
        self.stack = []
        # Rendered string fragments, joined by ``gen`` at the end.
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """Render ``expression`` to a pseudo-SQL string.

        Args:
            expression: the root node to render.
            comments: when True, a node's comments are emitted right after
                it as `` /*c1,c2*/``.

        Returns:
            The concatenation of all rendered fragments.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    # Pushed before the node's parts, so it is popped
                    # (emitted) after them.
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: upper-cased key, then its non-None args.
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Each non-None item is pushed (in reverse) together with a
                # "," separator. The separator pushed last would precede the
                # first item, so it is dropped. Only drop it if something was
                # actually pushed; otherwise an all-None list would pop an
                # unrelated entry, e.g. a function call's closing ")".
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        # Renders as "<this> AS <alias>".
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        # Names are upper-cased, except quoted identifiers, which keep their
        # case and quotes.
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        # Renders as "<this> BETWEEN <low> AND <high>".
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        # Renders as "<this>[<expressions>]".
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Renders the column's parts joined by ".".
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        if parts:
            # Drop the "." that would precede the first part. Guarded so an
            # empty parts list does not pop an unrelated stack entry.
            self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Type name (``e.this.name``) followed by a space, then the args
        # after the first arg type.
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        # Renders as "<this> IN (<args after the first arg type>)".
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): mixed-case " Like " differs from the other upper-case
        # operators. It is left unchanged because gen() output is used for
        # sortable/dedup strings, and normalizing it could change ordering.
        # Confirm before changing.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Renders as "(<this>)<alias><args after the first two arg types>".
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Renders dotted parts, then alias, then args beyond the first four
        # arg types.
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        if parts:
            # Drop the "." that would precede the first part. Guarded so an
            # empty parts list does not pop the alias or args entry instead.
            self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        # Renders as " AS <this>" plus "(<columns>)" when columns are present.
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push ``e`` so it renders as ``<this><op><expression>``."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push ``e`` so it renders as ``<op><this>``."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push ``e`` so it renders as ``NAME(<all arg values>)``.

        ``None`` arg values are skipped by the list handling in ``gen``.
        """
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the node's non-None args as ``[":key", value]`` pairs.

        Args:
            node: the expression whose args are rendered.
            arg_index: number of leading ``arg_types`` entries to skip.

        Returns:
            True if at least one arg was pushed, else False.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: exp.Expression, comments: bool = False) -> str:
    """Render ``expression`` to a pseudo-SQL string.

    Args:
        expression: the root node to render.
        comments: when True, a node's comments are emitted right after it as
            `` /*c1,c2*/``.

    Returns:
        The concatenation of all rendered fragments.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            if comments and node.comments:
                # Pushed before the node's parts, so it is emitted after them.
                self.stack.append(f" /*{','.join(node.comments)}*/")

            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            # Drop the leading "," separator only if some item was actually
            # pushed; an all-None list must not pop an unrelated entry.
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            if pushed:
                self.stack.pop()
        else:
            if node is not None:
                self.sqls.append(str(node))

    return "".join(self.sqls)
def add_sql(self, e: exp.Add) -> None:
    """Push ``e`` so it renders as ``<this> + <expression>``."""
    self._binary(e, op=" + ")
def alias_sql(self, e: exp.Alias) -> None:
    """Push ``e`` so it renders as ``<this> AS <alias>``."""
    # Reverse output order: the first piece pushed is emitted last.
    for piece in (e.args.get("alias"), " AS ", e.args.get("this")):
        self.stack.append(piece)
def and_sql(self, e: exp.And) -> None:
    """Push ``e`` so it renders as ``<this> AND <expression>``."""
    self._binary(e, op=" AND ")
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """Push ``e`` so it renders as ``NAME(<expressions>)``.

    A plain string name is upper-cased. An Identifier is upper-cased unless
    quoted, in which case it keeps its case and is wrapped in double quotes.

    Raises:
        ValueError: if ``e.this`` is neither a str nor an Identifier.
    """
    this = e.this
    if isinstance(this, exp.Identifier):
        name = f'"{this.this}"' if this.quoted else this.this.upper()
    elif isinstance(this, str):
        name = this.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
        )

    self.stack += [")", e.expressions, "(", name]
def between_sql(self, e: exp.Between) -> None:
    """Push ``e`` so it renders as ``<this> BETWEEN <low> AND <high>``."""
    high = e.args.get("high")
    low = e.args.get("low")
    self.stack += [high, " AND ", low, " BETWEEN ", e.this]
def boolean_sql(self, e: exp.Boolean) -> None:
    """Push ``TRUE`` or ``FALSE`` depending on the truthiness of ``e.this``."""
    if e.this:
        self.stack.append("TRUE")
    else:
        self.stack.append("FALSE")
def bracket_sql(self, e: exp.Bracket) -> None:
    """Push ``e`` so it renders as ``<this>[<expressions>]``."""
    self.stack += ["]", e.expressions, "[", e.this]
def column_sql(self, e: exp.Column) -> None:
    """Push ``e`` so it renders as its parts joined by ``.``."""
    parts = e.parts
    for p in reversed(parts):
        self.stack.extend((p, "."))
    if parts:
        # Drop the "." that would precede the first part. Guarded so an empty
        # parts list does not pop an unrelated stack entry.
        self.stack.pop()
def datatype_sql(self, e: exp.DataType) -> None:
    """Push ``e`` as ``e.this.name`` plus a space, then args after the first arg type."""
    self._args(e, arg_index=1)
    type_name = e.this.name
    self.stack.append(f"{type_name} ")
def div_sql(self, e: exp.Div) -> None:
    """Push ``e`` so it renders as ``<this> / <expression>``."""
    self._binary(e, op=" / ")
def dot_sql(self, e: exp.Dot) -> None:
    """Push ``e`` so it renders as ``<this>.<expression>``."""
    self._binary(e, op=".")
def eq_sql(self, e: exp.EQ) -> None:
    """Push ``e`` so it renders as ``<this> = <expression>``."""
    self._binary(e, op=" = ")
def from_sql(self, e: exp.From) -> None:
    """Push ``e`` so it renders as ``FROM <this>``."""
    self.stack += [e.this, "FROM "]
def gt_sql(self, e: exp.GT) -> None:
    """Push ``e`` so it renders as ``<this> > <expression>``."""
    self._binary(e, op=" > ")
def gte_sql(self, e: exp.GTE) -> None:
    """Push ``e`` so it renders as ``<this> >= <expression>``."""
    self._binary(e, op=" >= ")
def identifier_sql(self, e: exp.Identifier) -> None:
    """Push the identifier name, wrapped in double quotes when quoted."""
    if e.quoted:
        self.stack.append(f'"{e.this}"')
    else:
        self.stack.append(e.this)
def ilike_sql(self, e: exp.ILike) -> None:
    """Push ``e`` so it renders as ``<this> ILIKE <expression>``."""
    self._binary(e, op=" ILIKE ")
def in_sql(self, e: exp.In) -> None:
    """Push ``e`` as ``<this> IN (...)`` with the args after the first arg type inside."""
    self.stack.append(")")
    self._args(e, arg_index=1)
    self.stack += ["(", " IN ", e.this]
def intdiv_sql(self, e: exp.IntDiv) -> None:
    """Push ``e`` so it renders as ``<this> DIV <expression>``."""
    self._binary(e, op=" DIV ")
def is_sql(self, e: exp.Is) -> None:
    """Push ``e`` so it renders as ``<this> IS <expression>``."""
    self._binary(e, op=" IS ")
def like_sql(self, e: exp.Like) -> None:
    """Push ``e`` so it renders as ``<this> Like <expression>``."""
    # NOTE(review): mixed-case " Like " differs from the other upper-case
    # operators. It is kept as-is because it affects the generated strings.
    self._binary(e, op=" Like ")
def literal_sql(self, e: exp.Literal) -> None:
    """Push the literal value; string literals are wrapped in single quotes."""
    if e.is_string:
        self.stack.append(f"'{e.this}'")
    else:
        self.stack.append(e.this)
def lt_sql(self, e: exp.LT) -> None:
    """Push ``e`` so it renders as ``<this> < <expression>``."""
    self._binary(e, op=" < ")
def lte_sql(self, e: exp.LTE) -> None:
    """Push ``e`` so it renders as ``<this> <= <expression>``."""
    self._binary(e, op=" <= ")
def mod_sql(self, e: exp.Mod) -> None:
    """Push ``e`` so it renders as ``<this> % <expression>``."""
    self._binary(e, op=" % ")
def mul_sql(self, e: exp.Mul) -> None:
    """Push ``e`` so it renders as ``<this> * <expression>``."""
    self._binary(e, op=" * ")
def neg_sql(self, e: exp.Neg) -> None:
    """Push ``e`` so it renders as ``-<this>``."""
    self._unary(e, op="-")
def neq_sql(self, e: exp.NEQ) -> None:
    """Push ``e`` so it renders as ``<this> <> <expression>``."""
    self._binary(e, op=" <> ")
def not_sql(self, e: exp.Not) -> None:
    """Push ``e`` so it renders as ``NOT <this>``."""
    self._unary(e, op="NOT ")
def null_sql(self, e: exp.Null) -> None:
    """Push the constant ``NULL``."""
    self.stack += ["NULL"]
def or_sql(self, e: exp.Or) -> None:
    """Push ``e`` so it renders as ``<this> OR <expression>``."""
    self._binary(e, op=" OR ")
def paren_sql(self, e: exp.Paren) -> None:
    """Push ``e`` so it renders as ``(<this>)``."""
    self.stack += [")", e.this, "("]
def sub_sql(self, e: exp.Sub) -> None:
    """Push ``e`` so it renders as ``<this> - <expression>``."""
    self._binary(e, op=" - ")
def subquery_sql(self, e: exp.Subquery) -> None:
    """Push ``e`` as ``(<this>)``, then its alias, then args after the first two arg types."""
    self._args(e, arg_index=2)
    alias_node = e.args.get("alias")
    if alias_node:
        self.stack.append(alias_node)
    self.stack += [")", e.this, "("]
def table_sql(self, e: exp.Table) -> None:
    """Push ``e`` as its dotted parts, then its alias, then args beyond the first four arg types."""
    self._args(e, 4)
    alias = e.args.get("alias")
    if alias:
        self.stack.append(alias)
    parts = e.parts
    for p in reversed(parts):
        self.stack.extend((p, "."))
    if parts:
        # Drop the "." that would precede the first part. Guarded so an empty
        # parts list does not pop the alias or args entry instead.
        self.stack.pop()
def tablealias_sql(self, e: exp.TableAlias) -> None:
    """Push ``e`` as `` AS <this>``, followed by ``(<columns>)`` when columns exist."""
    cols = e.columns
    if cols:
        self.stack += [")", cols, "("]
    self.stack += [e.this, " AS "]
def var_sql(self, e: exp.Var) -> None:
    """Push the variable's raw ``this`` value."""
    self.stack += [e.this]