Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import logging
   5import functools
   6import itertools
   7import typing as t
   8from collections import deque, defaultdict
   9from functools import reduce
  10
  11import sqlglot
  12from sqlglot import Dialect, exp
  13from sqlglot.helper import first, merge_ranges, while_changing
  14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  15
# Names below are only needed by static type checkers and are not bound at runtime.
if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    # Callable type: (exp.Expression, datetime.date, str, Dialect, exp.DataType)
    # -> Optional[exp.Expression].
    # NOTE(review): presumably the signature of the per-operator DATE_TRUNC rewrite helpers
    # (the str being a unit name) -- confirm against the code that uses this alias.
    DateTruncBinaryTransform = t.Callable[
        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
    ]
  22
# Logger registered under the "sqlglot" name
logger = logging.getLogger("sqlglot")

# Key stored in `Expression.meta`. Final means that an expression should not be simplified:
# `simplify` skips every node whose meta has this flag set.
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers, used by `_simplify_integer_cast`
# to decide whether a cast around an integer literal can be dropped.
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
  33
  34
  35class UnsupportedUnit(Exception):
  36    pass
  37
  38
  39def simplify(
  40    expression: exp.Expression,
  41    constant_propagation: bool = False,
  42    coalesce_simplification: bool = False,
  43    dialect: DialectType = None,
  44):
  45    """
  46    Rewrite sqlglot AST to simplify expressions.
  47
  48    Example:
  49        >>> import sqlglot
  50        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
  51        >>> simplify(expression).sql()
  52        'TRUE'
  53
  54    Args:
  55        expression: expression to simplify
  56        constant_propagation: whether the constant propagation rule should be used
  57        coalesce_simplification: whether the simplify coalesce rule should be used.
  58            This rule tries to remove coalesce functions, which can be useful in certain analyses but
  59            can leave the query more verbose.
  60    Returns:
  61        sqlglot.Expression: simplified expression
  62    """
  63
  64    dialect = Dialect.get_or_raise(dialect)
  65
  66    def _simplify(expression):
  67        pre_transformation_stack = [expression]
  68        post_transformation_stack = []
  69
  70        while pre_transformation_stack:
  71            node = pre_transformation_stack.pop()
  72
  73            if node.meta.get(FINAL):
  74                continue
  75
  76            # group by expressions cannot be simplified, for example
  77            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
  78            # the projection must exactly match the group by key
  79            group = node.args.get("group")
  80
  81            if group and hasattr(node, "selects"):
  82                groups = set(group.expressions)
  83                group.meta[FINAL] = True
  84
  85                for s in node.selects:
  86                    for n in s.walk():
  87                        if n in groups:
  88                            s.meta[FINAL] = True
  89                            break
  90
  91                having = node.args.get("having")
  92                if having:
  93                    for n in having.walk():
  94                        if n in groups:
  95                            having.meta[FINAL] = True
  96                            break
  97
  98            parent = node.parent
  99            root = node is expression
 100
 101            new_node = rewrite_between(node)
 102            new_node = uniq_sort(new_node, root)
 103            new_node = absorb_and_eliminate(new_node, root)
 104            new_node = simplify_concat(new_node)
 105            new_node = simplify_conditionals(new_node)
 106
 107            if constant_propagation:
 108                new_node = propagate_constants(new_node, root)
 109
 110            if new_node is not node:
 111                node.replace(new_node)
 112
 113            pre_transformation_stack.extend(
 114                n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
 115            )
 116            post_transformation_stack.append((new_node, parent))
 117
 118        while post_transformation_stack:
 119            node, parent = post_transformation_stack.pop()
 120            root = node is expression
 121
 122            # Resets parent, arg_key, index pointers– this is needed because some of the
 123            # previous transformations mutate the AST, leading to an inconsistent state
 124            for k, v in tuple(node.args.items()):
 125                node.set(k, v)
 126
 127            # Post-order transformations
 128            new_node = simplify_not(node)
 129            new_node = flatten(new_node)
 130            new_node = simplify_connectors(new_node, root)
 131            new_node = remove_complements(new_node, root)
 132
 133            if coalesce_simplification:
 134                new_node = simplify_coalesce(new_node, dialect)
 135
 136            new_node.parent = parent
 137
 138            new_node = simplify_literals(new_node, root)
 139            new_node = simplify_equality(new_node)
 140            new_node = simplify_parens(new_node, dialect)
 141            new_node = simplify_datetrunc(new_node, dialect)
 142            new_node = sort_comparison(new_node)
 143            new_node = simplify_startswith(new_node)
 144
 145            if new_node is not node:
 146                node.replace(new_node)
 147
 148        return new_node
 149
 150    expression = while_changing(expression, _simplify)
 151    remove_where_true(expression)
 152    return expression
 153
 154
 155def catch(*exceptions):
 156    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 157
 158    def decorator(func):
 159        def wrapped(expression, *args, **kwargs):
 160            try:
 161                return func(expression, *args, **kwargs)
 162            except exceptions:
 163                return expression
 164
 165        return wrapped
 166
 167    return decorator
 168
 169
 170def rewrite_between(expression: exp.Expression) -> exp.Expression:
 171    """Rewrite x between y and z to x >= y AND x <= z.
 172
 173    This is done because comparison simplification is only done on lt/lte/gt/gte.
 174    """
 175    if isinstance(expression, exp.Between):
 176        negate = isinstance(expression.parent, exp.Not)
 177
 178        expression = exp.and_(
 179            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 180            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 181            copy=False,
 182        )
 183
 184        if negate:
 185            expression = exp.paren(expression, copy=False)
 186
 187    return expression
 188
 189
# Maps a comparison class to the class `simplify_not` replaces it with when the comparison
# is negated, e.g. NOT x < y -> x >= y.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}

# Maps ALL <-> ANY. When `simplify_not` negates a comparison whose right-hand side is one of
# these subquery predicates, the predicate is swapped as well.
COMPLEMENT_SUBQUERY_PREDICATES = {
    exp.All: exp.Any,
    exp.Any: exp.All,
}
 203
 204
 205def simplify_not(expression):
 206    """
 207    Demorgan's Law
 208    NOT (x OR y) -> NOT x AND NOT y
 209    NOT (x AND y) -> NOT x OR NOT y
 210    """
 211    if isinstance(expression, exp.Not):
 212        this = expression.this
 213        if is_null(this):
 214            return exp.null()
 215        if this.__class__ in COMPLEMENT_COMPARISONS:
 216            right = this.expression
 217            complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
 218            if complement_subquery_predicate:
 219                right = complement_subquery_predicate(this=right.this)
 220
 221            return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
 222        if isinstance(this, exp.Paren):
 223            condition = this.unnest()
 224            if isinstance(condition, exp.And):
 225                return exp.paren(
 226                    exp.or_(
 227                        exp.not_(condition.left, copy=False),
 228                        exp.not_(condition.right, copy=False),
 229                        copy=False,
 230                    )
 231                )
 232            if isinstance(condition, exp.Or):
 233                return exp.paren(
 234                    exp.and_(
 235                        exp.not_(condition.left, copy=False),
 236                        exp.not_(condition.right, copy=False),
 237                        copy=False,
 238                    )
 239                )
 240            if is_null(condition):
 241                return exp.null()
 242        if always_true(this):
 243            return exp.false()
 244        if is_false(this):
 245            return exp.true()
 246        if isinstance(this, exp.Not):
 247            # double negation
 248            # NOT NOT x -> x
 249            return this.this
 250    return expression
 251
 252
 253def flatten(expression):
 254    """
 255    A AND (B AND C) -> A AND B AND C
 256    A OR (B OR C) -> A OR B OR C
 257    """
 258    if isinstance(expression, exp.Connector):
 259        for node in expression.args.values():
 260            child = node.unnest()
 261            if isinstance(child, expression.__class__):
 262                node.replace(child)
 263    return expression
 264
 265
 266def simplify_connectors(expression, root=True):
 267    def _simplify_connectors(expression, left, right):
 268        if isinstance(expression, exp.And):
 269            if is_false(left) or is_false(right):
 270                return exp.false()
 271            if is_zero(left) or is_zero(right):
 272                return exp.false()
 273            if is_null(left) or is_null(right):
 274                return exp.null()
 275            if always_true(left) and always_true(right):
 276                return exp.true()
 277            if always_true(left):
 278                return right
 279            if always_true(right):
 280                return left
 281            return _simplify_comparison(expression, left, right)
 282        elif isinstance(expression, exp.Or):
 283            if always_true(left) or always_true(right):
 284                return exp.true()
 285            if (
 286                (is_null(left) and is_null(right))
 287                or (is_null(left) and always_false(right))
 288                or (always_false(left) and is_null(right))
 289            ):
 290                return exp.null()
 291            if is_false(left):
 292                return right
 293            if is_false(right):
 294                return left
 295            return _simplify_comparison(expression, left, right, or_=True)
 296        elif isinstance(expression, exp.Xor):
 297            if left == right:
 298                return exp.false()
 299
 300    if isinstance(expression, exp.Connector):
 301        return _flat_simplify(expression, _simplify_connectors, root)
 302    return expression
 303
 304
# Inequality classes grouped by direction
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Comparison classes that the comparison, equality and coalesce rules operate on
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Maps an inequality to the one pointing in the opposite direction.
# NOTE(review): presumably used when the operands of a comparison are swapped -- confirm
# against the code that reads this table.
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# A shared operand containing one of these is ignored by `_simplify_comparison`
NONDETERMINISTIC = (exp.Rand, exp.Randn)
# Connectors handled by the complement / absorption rules
AND_OR = (exp.And, exp.Or)
 325
 326
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons that constrain the same term into one.

    Args:
        expression: the connector that joins `left` and `right`.
        left: the first operand of the connector.
        right: the second operand of the connector.
        or_: True when the operands are joined by OR, False for AND.

    Returns:
        `left`, `right` or a FALSE literal when the pair can be reduced, None when no rule
        applies, and `expression` itself when neither comparison has a non-shared operand.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        # Both comparisons are expected to carry exactly two args (their operands)
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands that appear in both comparisons; only the non-constant, deterministic
        # ones count as the shared term being constrained
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            try:
                # The operand each comparison checks the shared term against
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # Convert both values into Python objects that can be compared with each other
            if l.is_number and r.is_number:
                l = l.to_py()
                r = r.to_py()
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None
                # python won't compare date and datetime, but many engines will upcast
                l, r = cast_as_datetime(l), cast_as_datetime(r)

            # Every rule is tried with the comparisons in both orders
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Same direction: AND keeps the tighter bound, OR keeps the looser one.
                # NOTE(review): for equal values the tie goes to `left` under AND, which is
                # only right if a strict comparison always precedes its non-strict twin --
                # presumably guaranteed by the operand ordering of `uniq_sort`; confirm.
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Opposite directions whose ranges do not overlap can never both hold
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # An equality either contradicts the other comparison or implies it
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 387
 388
 389def remove_complements(expression, root=True):
 390    """
 391    Removing complements.
 392
 393    A AND NOT A -> FALSE
 394    A OR NOT A -> TRUE
 395    """
 396    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
 397        ops = set(expression.flatten())
 398        for op in ops:
 399            if isinstance(op, exp.Not) and op.this in ops:
 400                return exp.false() if isinstance(expression, exp.And) else exp.true()
 401
 402    return expression
 403
 404
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered by the string `gen` produces for them. AND / OR operands that
    produce the same string are also deduplicated; XOR operands are only sorted.

    Args:
        expression: the node to inspect; non-connectors are returned unchanged.
        root: whether `expression` is the root of the tree being simplified. A non-root
            connector is only processed when its `same_parent` is falsy.

    Returns:
        A rebuilt connector when sorting or deduplication was needed, else `expression`.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            # Keyed by generated string, so later duplicates overwrite earlier ones
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        # (`arr[i]` is the predecessor of `sql`, because enumeration starts at `arr[1]`)
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if deduped and len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression
 436
 437
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are applied in place, by replacing operands inside `expression` with
    constants or simpler terms; folding those constants away is left to other rules.

    Args:
        expression: the node to inspect; only AND / OR connectors are processed.
        root: whether `expression` is the root of the tree being simplified. A non-root
            connector is only processed when its `same_parent` is falsy.

    Returns:
        The `expression` object that was passed in.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # The connector type of the nested operands these laws apply to (the opposite one)
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # Index the operand under the pair it would form with the negation removed,
            # remembering which side survives if the complementary operand exists
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A negated side whose positive form is a sibling operand becomes the constant
            # that is neutral for the nested connector
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # An operand whose terms strictly contain those of another operand is redundant
            # and becomes the constant that is neutral for the outer connector
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
 508
 509
 510def propagate_constants(expression, root=True):
 511    """
 512    Propagate constants for conjunctions in DNF:
 513
 514    SELECT * FROM t WHERE a = b AND b = 5 becomes
 515    SELECT * FROM t WHERE a = 5 AND b = 5
 516
 517    Reference: https://www.sqlite.org/optoverview.html
 518    """
 519
 520    if (
 521        isinstance(expression, exp.And)
 522        and (root or not expression.same_parent)
 523        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 524    ):
 525        constant_mapping = {}
 526        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 527            if isinstance(expr, exp.EQ):
 528                l, r = expr.left, expr.right
 529
 530                # TODO: create a helper that can be used to detect nested literal expressions such
 531                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 532                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 533                    constant_mapping[l] = (id(l), r)
 534
 535        if constant_mapping:
 536            for column in find_all_in_scope(expression, exp.Column):
 537                parent = column.parent
 538                column_id, constant = constant_mapping.get(column) or (None, None)
 539                if (
 540                    column_id is not None
 541                    and id(column) != column_id
 542                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 543                ):
 544                    column.replace(constant.copy())
 545
 546    return expression
 547
 548
# Maps a date arithmetic function to the binary operator that undoes it. Also used by
# `simplify_literals` to recognise date arithmetic nodes.
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# Every operation `simplify_equality` can move to the other side of a comparison, mapped to
# the operator applied there.
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 561
 562
 563def _is_number(expression: exp.Expression) -> bool:
 564    return expression.is_number
 565
 566
 567def _is_interval(expression: exp.Expression) -> bool:
 568    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 569
 570
 571@catch(ModuleNotFoundError, UnsupportedUnit)
 572def simplify_equality(expression: exp.Expression) -> exp.Expression:
 573    """
 574    Use the subtraction and addition properties of equality to simplify expressions:
 575
 576        x + 1 = 3 becomes x = 2
 577
 578    There are two binary operations in the above expression: + and =
 579    Here's how we reference all the operands in the code below:
 580
 581          l     r
 582        x + 1 = 3
 583        a   b
 584    """
 585    if isinstance(expression, COMPARISONS):
 586        l, r = expression.left, expression.right
 587
 588        if l.__class__ not in INVERSE_OPS:
 589            return expression
 590
 591        if r.is_number:
 592            a_predicate = _is_number
 593            b_predicate = _is_number
 594        elif _is_date_literal(r):
 595            a_predicate = _is_date_literal
 596            b_predicate = _is_interval
 597        else:
 598            return expression
 599
 600        if l.__class__ in INVERSE_DATE_OPS:
 601            l = t.cast(exp.IntervalOp, l)
 602            a = l.this
 603            b = l.interval()
 604        else:
 605            l = t.cast(exp.Binary, l)
 606            a, b = l.left, l.right
 607
 608        if not a_predicate(a) and b_predicate(b):
 609            pass
 610        elif not a_predicate(b) and b_predicate(a):
 611            a, b = b, a
 612        else:
 613            return expression
 614
 615        return expression.__class__(
 616            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 617        )
 618    return expression
 619
 620
 621def simplify_literals(expression, root=True):
 622    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 623        return _flat_simplify(expression, _simplify_binary, root)
 624
 625    if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
 626        return expression.this.this
 627
 628    if type(expression) in INVERSE_DATE_OPS:
 629        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 630
 631    return expression
 632
 633
# Operators that `_simplify_binary` leaves alone instead of collapsing to NULL when one of
# their operands is NULL
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 635
 636
 637def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
 638    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
 639        this = _simplify_integer_cast(expr.this)
 640    else:
 641        this = expr.this
 642
 643    if isinstance(expr, exp.Cast) and this.is_int:
 644        num = this.to_py()
 645
 646        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
 647        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
 648        # engine-dependent
 649        if (
 650            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
 651        ) or (
 652            UTINYINT_MIN <= num <= UTINYINT_MAX
 653            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
 654        ):
 655            return this
 656
 657    return expr
 658
 659
 660def _simplify_binary(expression, a, b):
 661    if isinstance(expression, COMPARISONS):
 662        a = _simplify_integer_cast(a)
 663        b = _simplify_integer_cast(b)
 664
 665    if isinstance(expression, exp.Is):
 666        if isinstance(b, exp.Not):
 667            c = b.this
 668            not_ = True
 669        else:
 670            c = b
 671            not_ = False
 672
 673        if is_null(c):
 674            if isinstance(a, exp.Literal):
 675                return exp.true() if not_ else exp.false()
 676            if is_null(a):
 677                return exp.false() if not_ else exp.true()
 678    elif isinstance(expression, NULL_OK):
 679        return None
 680    elif is_null(a) or is_null(b):
 681        return exp.null()
 682
 683    if a.is_number and b.is_number:
 684        num_a = a.to_py()
 685        num_b = b.to_py()
 686
 687        if isinstance(expression, exp.Add):
 688            return exp.Literal.number(num_a + num_b)
 689        if isinstance(expression, exp.Mul):
 690            return exp.Literal.number(num_a * num_b)
 691
 692        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 693        if isinstance(expression, exp.Sub):
 694            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 695        if isinstance(expression, exp.Div):
 696            # engines have differing int div behavior so intdiv is not safe
 697            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 698                return None
 699            return exp.Literal.number(num_a / num_b)
 700
 701        boolean = eval_boolean(expression, num_a, num_b)
 702
 703        if boolean:
 704            return boolean
 705    elif a.is_string and b.is_string:
 706        boolean = eval_boolean(expression, a.this, b.this)
 707
 708        if boolean:
 709            return boolean
 710    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 711        date, b = extract_date(a), extract_interval(b)
 712        if date and b:
 713            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 714                return date_literal(date + b, extract_type(a))
 715            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 716                return date_literal(date - b, extract_type(a))
 717    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 718        a, date = extract_interval(a), extract_date(b)
 719        # you cannot subtract a date from an interval
 720        if a and b and isinstance(expression, exp.Add):
 721            return date_literal(a + date, extract_type(b))
 722    elif _is_date_literal(a) and _is_date_literal(b):
 723        if isinstance(expression, exp.Predicate):
 724            a, b = extract_date(a), extract_date(b)
 725            boolean = eval_boolean(expression, a, b)
 726            if boolean:
 727                return boolean
 728
 729    return None
 730
 731
 732def simplify_parens(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
 733    if not isinstance(expression, exp.Paren):
 734        return expression
 735
 736    this = expression.this
 737    parent = expression.parent
 738    parent_is_predicate = isinstance(parent, exp.Predicate)
 739
 740    if isinstance(this, exp.Select):
 741        return expression
 742
 743    if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
 744        return expression
 745
 746    # Handle risingwave struct columns
 747    # see https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct
 748    if (
 749        dialect == "risingwave"
 750        and isinstance(parent, exp.Dot)
 751        and (isinstance(parent.right, (exp.Identifier, exp.Star)))
 752    ):
 753        return expression
 754
 755    if (
 756        not isinstance(parent, (exp.Condition, exp.Binary))
 757        or isinstance(parent, exp.Paren)
 758        or (
 759            not isinstance(this, exp.Binary)
 760            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
 761        )
 762        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
 763        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 764        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 765        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 766    ):
 767        return this
 768
 769    return expression
 770
 771
 772def _is_nonnull_constant(expression: exp.Expression) -> bool:
 773    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 774
 775
 776def _is_constant(expression: exp.Expression) -> bool:
 777    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 778
 779
def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
    """
    Remove COALESCE calls, or rewrite comparisons against them, when constants allow it.

    Args:
        expression: the node to inspect, either a COALESCE call or a comparison that has a
            COALESCE call on one side.
        dialect: the target dialect; comparisons are never rewritten for "redshift".

    Returns:
        The first COALESCE argument, a parenthesized disjunction that replaces the
        comparison, or `expression` unchanged when no rule applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
    # because they are not always equivalent. For example,  if `x` is `NULL` and it comes
    # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
    if dialect == "redshift":
        return expression

    if not isinstance(expression, COMPARISONS):
        return expression

    # Locate the COALESCE call and the operand it is compared against
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant and everything after it. This mutates the COALESCE inside
    # `expression`, so the copy of `expression` taken below already reflects it.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    # Shape: (NOT <coalesce> IS NULL AND <comparison>) OR (<coalesce> IS NULL AND <arg> <op> <other>)
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
 842
 843
# Expression classes that `simplify_concat` operates on
CONCATS = (exp.Concat, exp.DPipe)
 845
 846
 847def simplify_concat(expression):
 848    """Reduces all groups that contain string literals by concatenating them."""
 849    if not isinstance(expression, CONCATS) or (
 850        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 851        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 852    ):
 853        return expression
 854
 855    if isinstance(expression, exp.ConcatWs):
 856        sep_expr, *expressions = expression.expressions
 857        sep = sep_expr.name
 858        concat_type = exp.ConcatWs
 859        args = {}
 860    else:
 861        expressions = expression.expressions
 862        sep = ""
 863        concat_type = exp.Concat
 864        args = {
 865            "safe": expression.args.get("safe"),
 866            "coalesce": expression.args.get("coalesce"),
 867        }
 868
 869    new_args = []
 870    for is_string_group, group in itertools.groupby(
 871        expressions or expression.flatten(), lambda e: e.is_string
 872    ):
 873        if is_string_group:
 874            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 875        else:
 876            new_args.extend(group)
 877
 878    if len(new_args) == 1 and new_args[0].is_string:
 879        return new_args[0]
 880
 881    if concat_type is exp.ConcatWs:
 882        new_args = [sep_expr] + new_args
 883    elif isinstance(expression, exp.DPipe):
 884        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
 885
 886    return concat_type(expressions=new_args, **args)
 887
 888
 889def simplify_conditionals(expression):
 890    """Simplifies expressions like IF, CASE if their condition is statically known."""
 891    if isinstance(expression, exp.Case):
 892        this = expression.this
 893        for case in expression.args["ifs"]:
 894            cond = case.this
 895            if this:
 896                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 897                cond = cond.replace(this.pop().eq(cond))
 898
 899            if always_true(cond):
 900                return case.args["true"]
 901
 902            if always_false(cond):
 903                case.pop()
 904                if not expression.args["ifs"]:
 905                    return expression.args.get("default") or exp.null()
 906    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 907        if always_true(expression.this):
 908            return expression.args["true"]
 909        if always_false(expression.this):
 910            return expression.args.get("false") or exp.null()
 911
 912    return expression
 913
 914
 915def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 916    """
 917    Reduces a prefix check to either TRUE or FALSE if both the string and the
 918    prefix are statically known.
 919
 920    Example:
 921        >>> from sqlglot import parse_one
 922        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 923        'TRUE'
 924    """
 925    if (
 926        isinstance(expression, exp.StartsWith)
 927        and expression.this.is_string
 928        and expression.expression.is_string
 929    ):
 930        return exp.convert(expression.name.startswith(expression.expression.name))
 931
 932    return expression
 933
 934
# A (start, end) pair of dates; _datetrunc_range produces these as half-open [start, end) ranges.
DateRange = t.Tuple[datetime.date, datetime.date]
 936
 937
 938def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 939    """
 940    Get the date range for a DATE_TRUNC equality comparison:
 941
 942    Example:
 943        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 944    Returns:
 945        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 946    """
 947    floor = date_floor(date, unit, dialect)
 948
 949    if date != floor:
 950        # This will always be False, except for NULL values.
 951        return None
 952
 953    return floor, floor + interval(unit)
 954
 955
 956def _datetrunc_eq_expression(
 957    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
 958) -> exp.Expression:
 959    """Get the logical expression for a date range"""
 960    return exp.and_(
 961        left >= date_literal(drange[0], target_type),
 962        left < date_literal(drange[1], target_type),
 963        copy=False,
 964    )
 965
 966
 967def _datetrunc_eq(
 968    left: exp.Expression,
 969    date: datetime.date,
 970    unit: str,
 971    dialect: Dialect,
 972    target_type: t.Optional[exp.DataType],
 973) -> t.Optional[exp.Expression]:
 974    drange = _datetrunc_range(date, unit, dialect)
 975    if not drange:
 976        return None
 977
 978    return _datetrunc_eq_expression(left, drange, target_type)
 979
 980
 981def _datetrunc_neq(
 982    left: exp.Expression,
 983    date: datetime.date,
 984    unit: str,
 985    dialect: Dialect,
 986    target_type: t.Optional[exp.DataType],
 987) -> t.Optional[exp.Expression]:
 988    drange = _datetrunc_range(date, unit, dialect)
 989    if not drange:
 990        return None
 991
 992    return exp.and_(
 993        left < date_literal(drange[0], target_type),
 994        left >= date_literal(drange[1], target_type),
 995        copy=False,
 996    )
 997
 998
# Rewrites for `DATE_TRUNC(unit, x) <op> <date>`, keyed by the comparison's expression type.
# simplify_datetrunc calls each entry with (truncated operand, date, unit, dialect, target type)
# and falls back to the original expression when the entry returns a falsy value.
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # trunc(x) < dt: compare x against dt itself when dt is already truncated, otherwise
    # against the start of the period following dt's period
    exp.LT: lambda l, dt, u, d, t: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
    # trunc(x) > dt: x must be at or after the start of the period following dt's period
    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
    # trunc(x) <= dt: x must be before the start of the period following dt's period
    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
    # trunc(x) >= dt: x must be at or after dt rounded up to a period start
    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
# Every expression type simplify_datetrunc attempts to rewrite: IN plus the binary comparisons above
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
# Expression types recognized as date truncation calls
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
1010
1011
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return whether `left` is a date truncation and `right` is a date literal."""
    if not isinstance(left, DATETRUNCS):
        return False
    return _is_date_literal(right)
1014
1015
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """
    Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`.

    Three shapes are handled:
      - a truncation of a constant date, which is folded into a date literal;
      - a binary comparison between a truncation and a date literal, which is rewritten
        through DATETRUNC_BINARY_COMPARISONS;
      - `DATE_TRUNC(...) IN (<date literals>)`, which becomes an OR of range checks.

    Anything else is returned unchanged.
    """
    node_type = expression.__class__

    if isinstance(expression, DATETRUNCS):
        operand = expression.this
        trunc_type = extract_type(operand)
        constant = extract_date(operand)
        if constant and expression.unit:
            floored = date_floor(constant, expression.unit.name.lower(), dialect)
            return date_literal(floored, trunc_type)
    elif node_type not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        lhs, rhs = expression.left, expression.right

        if not _is_datetrunc_predicate(lhs, rhs):
            return expression

        lhs = t.cast(exp.DateTrunc, lhs)
        truncated_operand = lhs.this
        unit = lhs.unit.name.lower()
        constant = extract_date(rhs)

        if not constant:
            return expression

        rewrite = DATETRUNC_BINARY_COMPARISONS[node_type]
        rewritten = rewrite(truncated_operand, constant, unit, dialect, extract_type(rhs))
        # A falsy rewrite means the comparison cannot be expressed as a range check
        return rewritten or expression

    if not isinstance(expression, exp.In):
        return expression

    lhs = expression.this
    candidates = expression.expressions

    if not candidates or not all(_is_datetrunc_predicate(lhs, c) for c in candidates):
        return expression

    lhs = t.cast(exp.DateTrunc, lhs)
    unit = lhs.unit.name.lower()

    ranges = []
    for candidate in candidates:
        constant = extract_date(candidate)
        if not constant:
            return expression
        candidate_range = _datetrunc_range(constant, unit, dialect)
        # Candidates that yield no range are dropped from the rewrite
        if candidate_range:
            ranges.append(candidate_range)

    if not ranges:
        return expression

    ranges = merge_ranges(ranges)
    target_type = extract_type(*candidates)

    range_checks = [_datetrunc_eq_expression(lhs, drange, target_type) for drange in ranges]
    return exp.or_(*range_checks, copy=False)
1079
1080
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order.

    A comparison is left alone when its left side is a column and the right is not,
    when only its right side is constant, or when its right side is a subquery
    predicate. Otherwise the operands are swapped (using the inverse operator when
    one is registered) if only the right side is a column, only the left side is
    constant, or the left side sorts after the right side by its `gen` string.
    """
    node_type = expression.__class__
    if node_type not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_ordered = (
        (lhs_is_column and not rhs_is_column)
        or (rhs_is_const and not lhs_is_const)
        or isinstance(rhs, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    should_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if not should_swap:
        return expression

    # Swapping operands requires flipping directional operators (e.g. < becomes >)
    swapped_type = INVERSE_COMPARISONS.get(node_type, node_type)
    return swapped_type(this=rhs, expression=lhs)
1100
1101
# (side, kind) pairs of joins that remove_where_true may rewrite to a CROSS join when
# their ON condition is always true.
# CROSS joins result in an empty table if the right table is empty,
# so we can only simplify certain types of joins to CROSS.
# In other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
1111
1112
def remove_where_true(expression):
    """
    Strip conditions that are always true, mutating `expression` in place.

    WHERE clauses whose condition is always true are removed. Joins listed in JOINS
    whose ON condition is always true, and which have no USING clause or join method,
    are turned into CROSS joins.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
1127
1128
def always_true(expression):
    """Return a truthy value for a TRUE boolean or for any literal that is not zero."""
    if isinstance(expression, exp.Boolean) and expression.this:
        return expression.this

    return isinstance(expression, exp.Literal) and not is_zero(expression)
1133
1134
def always_false(expression):
    """Return whether `expression` is FALSE, NULL or a zero literal."""
    if is_false(expression):
        return True
    if is_null(expression):
        return True
    return is_zero(expression)
1137
1138
def is_zero(expression):
    """Return whether `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
1141
1142
def is_complement(a, b):
    """Return whether `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1145
1146
def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a Boolean node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1149
1150
def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
1153
1154
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison represented by `expression` on the Python values `a` and `b`.

    Returns:
        A TRUE/FALSE boolean literal, or None if `expression` is not one of the
        supported comparison types.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    # Only the first matching comparator is evaluated, in the order listed above
    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))

    return None
1169
1170
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`.

    Datetimes are reduced to their date part, dates are returned as they are, and
    anything else is parsed as an ISO-format string. Returns None when parsing
    raises ValueError.
    """
    if isinstance(value, datetime.date):
        # datetime.datetime is a subclass of datetime.date, so it needs its time part dropped
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None

    return parsed.date()
1180
1181
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`.

    Datetimes are returned as they are, dates become midnight of that day, and
    anything else is parsed as an ISO-format string. Returns None when parsing
    raises ValueError.
    """
    if isinstance(value, datetime.datetime):
        return value

    if isinstance(value, datetime.date):
        year, month, day = value.year, value.month, value.day
        return datetime.datetime(year=year, month=month, day=day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        parsed = None

    return parsed
1191
1192
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert `value` to the Python date/datetime matching the SQL type `to`.

    Args:
        value: the value to convert; falsy values are not converted.
        to: the target data type.

    Returns:
        A `datetime.date` for DATE targets, a `datetime.datetime` for the other
        temporal types, or None if `value` is falsy, `to` is not a temporal type,
        or the conversion fails.
    """
    if not value:
        return None
    # DATE is checked before the generic temporal types so it yields a date
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1201
1202
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Extract a Python date/datetime from a cast of a literal.

    Supports `CAST(<literal> AS <type>)` and format-less `TS_OR_DS_TO_DATE(<literal>)`,
    including nested casts, which are resolved from the inside out.

    Returns:
        The converted value, or None if `cast` has an unsupported shape or the
        value cannot be converted to the target type.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without an explicit format the conversion is treated as a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: resolve the inner value first, then convert it to the outer type
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
1218
1219
def _is_date_literal(expression: exp.Expression) -> bool:
    """Return whether a date can be statically extracted from `expression`."""
    extracted = extract_date(expression)
    return extracted is not None
1222
1223
def extract_interval(expression):
    """
    Convert an interval expression into the object produced by `interval`.

    Returns None when the unit is unsupported, the `dateutil` dependency is
    missing, or the interval count cannot be converted to an int.
    """
    try:
        count = int(expression.this.to_py())
        unit_name = expression.text("unit").lower()
        # Kept inside the try block: interval() is what raises UnsupportedUnit
        return interval(unit_name, count)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
1231
1232
def extract_type(*expressions):
    """
    Return the first truthy type found among `expressions`.

    For casts the target type (`to`) is used, otherwise the expression's `type`.
    Returns None when called without arguments.
    """
    found = None

    for node in expressions:
        found = node.to if isinstance(node, exp.Cast) else node.type
        if found:
            return found

    return found
1241
1242
def date_literal(date, target_type=None):
    """
    Build `CAST('<date>' AS <type>)` for a Python date or datetime.

    `target_type` is used when it is a temporal type; otherwise the type is
    DATETIME for datetimes and DATE for everything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_to = target_type
    elif isinstance(date, datetime.datetime):
        cast_to = exp.DataType.Type.DATETIME
    else:
        cast_to = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, cast_to)
1252
1253
def interval(unit: str, n: int = 1):
    """
    Build a `relativedelta` spanning `n` units.

    Args:
        unit: one of year, quarter, month, week, day, hour, minute, second.
        n: how many units the delta spans.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units above.
    """
    # Imported here so that a missing dateutil only surfaces when an interval is needed
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, number of that keyword's units per `unit`)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    span = spans.get(unit)
    if span is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = span
    return relativedelta(**{keyword: multiplier * n})
1275
1276
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date to floor. Only the date fields are adjusted; if a datetime is
            passed, its time fields are left untouched.
        unit: one of year, quarter, month, week, day.
        dialect: provides `WEEK_OFFSET`, the shift of the week start relative to
            Monday (0 = Monday; -1 makes weeks start on Sunday).

    Returns:
        The first day of the containing year/quarter/month/week, or `d` itself for "day".

    Raises:
        UnsupportedUnit: if `unit` is not one of the units above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        return d.replace(month=(d.month - 1) // 3 * 3 + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday through 6 for Sunday. Wrapping modulo 7 keeps the
        # result within the 7 days ending at `d` for any WEEK_OFFSET; without it a
        # non-zero offset could jump a full week back or land on a date after `d`.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1298
1299
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` up to the next `unit` boundary.

    A date that already sits on a boundary is returned unchanged.
    """
    boundary = date_floor(d, unit, dialect)
    return d if boundary == d else boundary + interval(unit)
1307
1308
def boolean_literal(condition):
    """Return the TRUE expression when `condition` is truthy, otherwise FALSE."""
    if condition:
        return exp.true()
    return exp.false()
1311
1312
def _flat_simplify(expression, simplifier, root=True):
    """
    Simplify the flattened operands of `expression` pairwise.

    `simplifier(expression, a, b)` is tried on each operand `a` against the operands
    that follow it. A truthy result other than `expression` itself replaces both
    operands and is retried against the rest. If any pair was merged, a new
    left-deep chain of the same expression type is built from what remains;
    otherwise `expression` is returned as is.

    Only runs for the root or for nodes whose `same_parent` is falsy.
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    original_count = len(pending)
    kept = []

    while pending:
        current = pending.popleft()

        for other in pending:
            merged = simplifier(expression, current, other)

            if merged and merged is not expression:
                # Swap the pair for its merged form and retry it against the rest
                pending.remove(other)
                pending.appendleft(merged)
                break
        else:
            # Nothing combined with `current`, so it survives unchanged
            kept.append(current)

    if len(kept) >= original_count:
        return expression

    node_type = expression.__class__
    return functools.reduce(lambda acc, operand: node_type(this=acc, expression=operand), kept)
1337
1338
def gen(expression: t.Any, comments: bool = False) -> str:
    """Cheap pseudo-SQL generator that produces sortable, unique strings.

    Optimization needs to sort and dedupe SQL, and the real generator is too
    expensive for that, so a bare-minimum generator is used instead.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)
1350
1351
class Gen:
    """
    Minimal stack-based pseudo-SQL generator.

    `gen` pops nodes off `self.stack`. Expressions are expanded by a `<key>_sql`
    handler (or a generic fallback) that pushes their pieces back in reverse output
    order; everything else is stringified and appended to `self.sqls`.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """
        Render `expression` as a pseudo-SQL string.

        Args:
            expression: the expression to render.
            comments: whether to append each node's comments as `/*...*/`.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    # Pushed before the node's own pieces, so it is emitted after them
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator pushed after the first item. Only do so if an item
                # was actually pushed, otherwise an unrelated stack entry would be lost.
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            # Quoted identifiers keep their case, unquoted ones are upper-cased
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator pushed after the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator pushed after the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `<this><op><expression>` (in reverse, so `this` is emitted first)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `<op><this>` (in reverse, so `op` is emitted first)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `<sql name>(<all arg values>)` for function nodes without a dedicated handler."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push the node's set args as `:<key>` / value pairs.

        Args:
            node: the expression whose args are pushed.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            Whether anything was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
logger = <Logger sqlglot (WARNING)>
FINAL = 'final'
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(builtins.Exception):
36class UnsupportedUnit(Exception):
37    pass

Common base class for all non-exit exceptions.

def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, coalesce_simplification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None):
 40def simplify(
 41    expression: exp.Expression,
 42    constant_propagation: bool = False,
 43    coalesce_simplification: bool = False,
 44    dialect: DialectType = None,
 45):
 46    """
 47    Rewrite sqlglot AST to simplify expressions.
 48
 49    Example:
 50        >>> import sqlglot
 51        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 52        >>> simplify(expression).sql()
 53        'TRUE'
 54
 55    Args:
 56        expression: expression to simplify
 57        constant_propagation: whether the constant propagation rule should be used
 58        coalesce_simplification: whether the simplify coalesce rule should be used.
 59            This rule tries to remove coalesce functions, which can be useful in certain analyses but
 60            can leave the query more verbose.
 61    Returns:
 62        sqlglot.Expression: simplified expression
 63    """
 64
 65    dialect = Dialect.get_or_raise(dialect)
 66
 67    def _simplify(expression):
 68        pre_transformation_stack = [expression]
 69        post_transformation_stack = []
 70
 71        while pre_transformation_stack:
 72            node = pre_transformation_stack.pop()
 73
 74            if node.meta.get(FINAL):
 75                continue
 76
 77            # group by expressions cannot be simplified, for example
 78            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 79            # the projection must exactly match the group by key
 80            group = node.args.get("group")
 81
 82            if group and hasattr(node, "selects"):
 83                groups = set(group.expressions)
 84                group.meta[FINAL] = True
 85
 86                for s in node.selects:
 87                    for n in s.walk():
 88                        if n in groups:
 89                            s.meta[FINAL] = True
 90                            break
 91
 92                having = node.args.get("having")
 93                if having:
 94                    for n in having.walk():
 95                        if n in groups:
 96                            having.meta[FINAL] = True
 97                            break
 98
 99            parent = node.parent
100            root = node is expression
101
102            new_node = rewrite_between(node)
103            new_node = uniq_sort(new_node, root)
104            new_node = absorb_and_eliminate(new_node, root)
105            new_node = simplify_concat(new_node)
106            new_node = simplify_conditionals(new_node)
107
108            if constant_propagation:
109                new_node = propagate_constants(new_node, root)
110
111            if new_node is not node:
112                node.replace(new_node)
113
114            pre_transformation_stack.extend(
115                n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
116            )
117            post_transformation_stack.append((new_node, parent))
118
119        while post_transformation_stack:
120            node, parent = post_transformation_stack.pop()
121            root = node is expression
122
123            # Resets parent, arg_key, index pointers– this is needed because some of the
124            # previous transformations mutate the AST, leading to an inconsistent state
125            for k, v in tuple(node.args.items()):
126                node.set(k, v)
127
128            # Post-order transformations
129            new_node = simplify_not(node)
130            new_node = flatten(new_node)
131            new_node = simplify_connectors(new_node, root)
132            new_node = remove_complements(new_node, root)
133
134            if coalesce_simplification:
135                new_node = simplify_coalesce(new_node, dialect)
136
137            new_node.parent = parent
138
139            new_node = simplify_literals(new_node, root)
140            new_node = simplify_equality(new_node)
141            new_node = simplify_parens(new_node, dialect)
142            new_node = simplify_datetrunc(new_node, dialect)
143            new_node = sort_comparison(new_node)
144            new_node = simplify_startswith(new_node)
145
146            if new_node is not node:
147                node.replace(new_node)
148
149        return new_node
150
151    expression = while_changing(expression, _simplify)
152    remove_where_true(expression)
153    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression: expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
  • coalesce_simplification: whether the simplify coalesce rule should be used. This rule tries to remove coalesce functions, which can be useful in certain analyses but can leave the query more verbose.
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """
    Decorator factory for simplification functions that may fail.

    If the decorated function raises any of `exceptions`, the expression passed as
    its first argument is returned unchanged. Other exceptions propagate.

    Args:
        *exceptions: exception types that should be swallowed.

    Returns:
        A decorator preserving the wrapped function's name and docstring.
    """

    def decorator(func):
        # Keep the original function's metadata for introspection and generated docs
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that skips a simplification function, returning the expression unchanged, if any of the given `exceptions` is raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN has to be
    spelled out first. If the BETWEEN sits directly under a NOT, the resulting
    conjunction is wrapped in parentheses. Any other expression is returned as is.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)
    operand = expression.this

    lower_bound = exp.GTE(this=operand.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=operand.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(conjunction, copy=False) if under_not else conjunction

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

COMPLEMENT_SUBQUERY_PREDICATES = {<class 'sqlglot.expressions.All'>: <class 'sqlglot.expressions.Any'>, <class 'sqlglot.expressions.Any'>: <class 'sqlglot.expressions.All'>}
def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a NOT expression; anything else is returned unchanged.

    De Morgan's laws are applied to a parenthesized connector:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    In addition, NOT NULL becomes NULL, comparisons found in COMPLEMENT_COMPARISONS
    are replaced by their complement, constants are folded and NOT NOT x becomes x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    negated = expression.this

    if is_null(negated):
        return exp.null()

    negated_cls = negated.__class__
    if negated_cls in COMPLEMENT_COMPARISONS:
        rhs = negated.expression
        # ALL <-> ANY must be swapped as well when the comparison is complemented
        flipped_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(rhs.__class__)
        if flipped_predicate:
            rhs = flipped_predicate(this=rhs.this)

        return COMPLEMENT_COMPARISONS[negated_cls](this=negated.this, expression=rhs)

    if isinstance(negated, exp.Paren):
        inner = negated.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # AND turns into OR and vice versa, with both operands negated
            dual = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                dual(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )
        if is_null(inner):
            return exp.null()

    if always_true(negated):
        return exp.false()
    if is_false(negated):
        return exp.true()
    if isinstance(negated, exp.Not):
        # double negation
        # NOT NOT x -> x
        return negated.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    Drop parentheses around operands that use the same connector as their parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        inner = operand.unnest()
        if isinstance(inner, connector_type):
            operand.replace(inner)

    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Simplify AND / OR / XOR connectors whose operands are statically known.

    Expressions that are not connectors are returned unchanged. Connectors are
    handed to `_flat_simplify` together with the pairwise rule defined below and
    the `root` flag.
    """

    def _simplify_connectors(expression, left, right):
        # Pairwise rule: returns a replacement for `left <connector> right`.
        # Falls off the end (returning None) when no rule applies, e.g. for an
        # XOR whose operands differ.
        if isinstance(expression, exp.And):
            # A FALSE or zero operand decides the conjunction before NULL is considered
            if is_false(left) or is_false(right):
                return exp.false()
            if is_zero(left) or is_zero(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            # A single always-true operand is dropped, leaving the other one
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            # An always-true operand decides the disjunction
            if always_true(left) or always_true(right):
                return exp.true()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and always_false(right))
                or (always_false(left) and is_null(right))
            ):
                return exp.null()
            # A FALSE operand is dropped, leaving the other one
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)
        elif isinstance(expression, exp.Xor):
            # x XOR x -> FALSE
            if left == right:
                return exp.false()

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
AND_OR = (<class 'sqlglot.expressions.And'>, <class 'sqlglot.expressions.Or'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR) or (not root and expression.same_parent):
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )
    if not has_complement:
        return expression

    return exp.false() if isinstance(expression, exp.And) else exp.true()

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if not (isinstance(expression, exp.Connector) and (root or not expression.same_parent)):
        return expression

    operands = tuple(expression.flatten())

    if isinstance(expression, exp.Xor):
        build = exp.xor
        # Do not deduplicate XOR as A XOR A != A if A == True
        unique = None
        keyed = tuple((gen(operand), operand) for operand in operands)
    else:
        build = exp.and_ if isinstance(expression, exp.And) else exp.or_
        unique = {gen(operand): operand for operand in operands}
        keyed = tuple(unique.items())

    # Only rebuild the connector if some operand sorts before its predecessor
    # A AND C AND B -> A AND B AND C
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(keyed, keyed[1:]))

    if out_of_order:
        expression = build(*(operand for _, operand in sorted(keyed)), copy=False)
    elif unique and len(unique) < len(operands):
        # already sorted, but duplicates still have to be dropped
        expression = build(*unique.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Only AND / OR expressions are processed, and only when `root` is truthy or
    `expression.same_parent` is falsy. Operands are rewritten in place via
    `replace` and the (possibly mutated) `expression` is returned.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the operands of that type are the
        # candidates for absorption / elimination.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # Register the operand under the key it would have if the negated
            # side were not negated, remembering the non-negated side.
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A negated side whose operand also appears at the top level is
            # replaced by the constant that is neutral for `kind`.
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # An operand whose sub-operands strictly contain another operand's
            # sub-operands is replaced by the constant that is neutral for the
            # outer connector.
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression

absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # Maps a column to (id of the column node that defines it, literal it equals)
    known = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue

        lhs, rhs = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            known[lhs] = (id(lhs), rhs)

    if not known:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        entry = known.get(column)
        if not entry:
            continue

        defining_id, constant = entry
        # The defining occurrence itself and `col IS NULL` checks are left alone
        is_null_check = isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)
        if defining_id is not None and id(column) != defining_id and not is_null_check:
            column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
160        def wrapped(expression, *args, **kwargs):
161            try:
162                return func(expression, *args, **kwargs)
163            except exceptions:
164                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Fold literal operands.

    Non-connector binaries go through `_flat_simplify` with `_simplify_binary`,
    a doubly negated operand is unwrapped, and expressions whose exact type is in
    INVERSE_DATE_OPS are handed to `_simplify_binary` with their interval.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if isinstance(operand, exp.Neg):
            # -(-x) -> x
            return operand.this

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens( expression: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
def simplify_parens(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
    """Remove a Paren node when the checks below decide it is redundant.

    Returns the wrapped expression when the parentheses can be dropped, otherwise
    the original `expression`. Non-Paren inputs are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    # Parenthesized SELECTs, and parens under a subquery predicate or bracket, are kept
    if isinstance(inner, exp.Select) or isinstance(outer, (exp.SubqueryPredicate, exp.Bracket)):
        return expression

    # Handle risingwave struct columns
    # see https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct
    if (
        dialect == "risingwave"
        and isinstance(outer, exp.Dot)
        and (isinstance(outer.right, (exp.Identifier, exp.Star)))
    ):
        return expression

    outer_is_predicate = isinstance(outer, exp.Predicate)
    inner_needs_grouping = isinstance(inner, exp.Binary) or (
        isinstance(inner, (exp.Not, exp.Is)) and outer_is_predicate
    )

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not inner_needs_grouping
        or (isinstance(inner, exp.Predicate) and not outer_is_predicate)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )

    return inner if redundant else expression
def simplify_coalesce( expression: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType]) -> sqlglot.expressions.Expression:
def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Simplify COALESCE calls and comparisons against them.

    Two rewrites are attempted:

    1. A COALESCE that has no further arguments, or whose first argument satisfies
       `_is_nonnull_constant`, is replaced by its first argument (unless its parent
       is a Hint).
    2. A comparison (one of COMPARISONS) between a COALESCE and a constant is expanded
       into `(NOT <coalesce> IS NULL AND <comparison>) OR (<coalesce> IS NULL AND
       <constant arg> <op> <other>)`, where the COALESCE has first been truncated to
       the arguments preceding its first constant argument.

    Args:
        expression: the expression to simplify.
        dialect: the dialect in use; the second rewrite is skipped when it equals "redshift".

    Returns:
        The simplified expression, or `expression` itself if nothing applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
    # because they are not always equivalent. For example,  if `x` is `NULL` and it comes
    # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
    if dialect == "redshift":
        return expression

    if not isinstance(expression, COMPARISONS):
        return expression

    # The COALESCE may be on either side of the comparison
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Truncate the COALESCE in place; `expression.copy()` below is taken after
    # this mutation, so the copy contains the truncated call.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Merges every run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        extra_args = {}
    else:
        operands = expression.expressions
        sep = ""
        concat_type = exp.Concat
        extra_args = {key: expression.args.get(key) for key in ("safe", "coalesce")}

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in runs:
        if all_strings:
            joined = sep.join(literal.name for literal in run)
            merged.append(exp.Literal.string(joined))
        else:
            merged.extend(run)

    # Everything collapsed into a single string literal
    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if concat_type is exp.ConcatWs:
        merged = [sep_expr] + merged
    elif isinstance(expression, exp.DPipe):
        # Rebuild the left-deep chain of || operators
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), merged)

    return concat_type(expressions=merged, **extra_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """Resolves IF and CASE expressions whose conditions are statically known."""
    if isinstance(expression, exp.Case):
        subject = expression.this
        for branch in expression.args["ifs"]:
            predicate = branch.this
            if subject:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                predicate = predicate.replace(subject.pop().eq(predicate))

            if always_true(predicate):
                return branch.args["true"]

            if always_false(predicate):
                branch.pop()
                # No branches left: fall back to the default (or NULL)
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()

        return expression

    if isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        condition = expression.this
        if always_true(condition):
            return expression.args["true"]
        if always_false(condition):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Folds a STARTSWITH call into TRUE or FALSE when both the string and the
    prefix are string literals.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    if expression.this.is_string and expression.expression.is_string:
        prefix = expression.expression.name
        return exp.convert(expression.name.startswith(prefix))

    return expression

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.Dialect, sqlglot.expressions.DataType], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.GTE'>}
def simplify_datetrunc(expression, *args, **kwargs):
160        def wrapped(expression, *args, **kwargs):
161            try:
162                return func(expression, *args, **kwargs)
163            except exceptions:
164                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Only expressions whose class is in COMPLEMENT_COMPARISONS are considered. A
    column stays on the left and a constant on the right; otherwise the operands
    are swapped (using the inverse comparison, if INVERSE_COMPARISONS has one) when
    that moves a column left, a constant right, or orders them by `gen` output.
    """
    comparison_cls = expression.__class__
    if comparison_cls not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_canonical = (
        (lhs_is_column and not rhs_is_column)
        or (rhs_is_const and not lhs_is_const)
        or isinstance(rhs, exp.SubqueryPredicate)
    )
    if already_canonical:
        return expression

    needs_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if needs_swap:
        swapped_cls = INVERSE_COMPARISONS.get(comparison_cls, comparison_cls)
        return swapped_cls(this=rhs, expression=lhs)

    return expression
JOINS = {('', 'INNER'), ('RIGHT', ''), ('RIGHT', 'OUTER'), ('', '')}
def remove_where_true(expression):
def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn always-true joins into CROSS joins.

    A join is only converted when it has no USING clause, no join method, and its
    (side, kind) pair is listed in JOINS. The expression is modified in place.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        on_clause = join.args.get("on")
        if not always_true(on_clause):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        on_clause.pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """Return a truthy value if `expression` is a truthy Boolean or a Literal that is not zero."""
    if isinstance(expression, exp.Boolean):
        value = expression.this
        if value:
            return value

    return isinstance(expression, exp.Literal) and not is_zero(expression)
def always_false(expression):
def always_false(expression):
    """Return True if `expression` is FALSE, NULL or a zero literal."""
    return any(check(expression) for check in (is_false, is_null, is_zero))
def is_zero(expression):
def is_zero(expression):
    """Return True if `expression` is a Literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False

    return expression.to_py() == 0
def is_complement(a, b):
def is_complement(a, b):
    """Return whether `b` is `NOT a`, i.e. a Not node whose operand equals `a`."""
    if not isinstance(b, exp.Not):
        return False

    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a Boolean node (not a subclass) with a falsy value."""
    if type(a) is not exp.Boolean:
        return False

    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a boolean literal expression for EQ / IS / NEQ / GT / GTE / LT / LTE,
    and None for any other kind of expression.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        outcome = a == b
    elif isinstance(expression, exp.NEQ):
        outcome = a != b
    elif isinstance(expression, exp.GT):
        outcome = a > b
    elif isinstance(expression, exp.GTE):
        outcome = a >= b
    elif isinstance(expression, exp.LT):
        outcome = a < b
    elif isinstance(expression, exp.LTE):
        outcome = a <= b
    else:
        return None

    return boolean_literal(outcome)
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`, if possible.

    Args:
        value: a datetime, a date, or an ISO-8601 formatted string.

    Returns:
        The corresponding date, or None if `value` is neither a date/datetime nor
        a string that `datetime.datetime.fromisoformat` can parse.
    """
    # datetime.datetime subclasses datetime.date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: `value` is not a string; ValueError: not a valid ISO string
        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`, if possible.

    Args:
        value: a datetime, a date (converted to midnight of that day), or an
            ISO-8601 formatted string.

    Returns:
        The corresponding datetime, or None if `value` is neither a date/datetime
        nor a string that `datetime.datetime.fromisoformat` can parse.
    """
    # datetime.datetime subclasses datetime.date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: `value` is not a string; ValueError: not a valid ISO string
        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python date/datetime matching the target type `to`.

    Args:
        value: the raw value to convert; falsy values yield None.
        to: the target data type.

    Returns:
        A date if `to` is DATE, a datetime if `to` is any other temporal type,
        and None otherwise or when the conversion fails.
    """
    if not value:
        return None
    # DATE is checked before the generic temporal check so that it yields a date
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a cast of a literal to a date or datetime.

    Handles `Cast` nodes and `TsOrDsToDate` nodes that have no format, possibly
    nested inside one another, whose innermost operand is a literal.

    Returns:
        The date/datetime produced by `cast_value`, or None if `cast` does not
        have one of the supported shapes.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: evaluate the inner one first, then convert its result
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """Build the delta returned by `interval` for an interval expression.

    Returns None when the unit is unsupported, when the module needed by
    `interval` is missing, or when a ValueError is raised while converting.
    """
    try:
        count = int(expression.this.to_py())
        return interval(expression.text("unit").lower(), count)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def extract_type(*expressions):
def extract_type(*expressions):
    """Return the first truthy type found among `expressions`.

    For a Cast the target type (`to`) is used, otherwise the expression's `type`.
    If no truthy type is found, the last value looked at is returned (None when
    no expressions are given).
    """
    found = None
    for candidate in expressions:
        if isinstance(candidate, exp.Cast):
            found = candidate.to
        else:
            found = candidate.type

        if found:
            break

    return found
def date_literal(date, target_type=None):
def date_literal(date, target_type=None):
    """Build a CAST of the string form of `date` to a temporal type.

    `target_type` is used if it is a temporal type; otherwise DATETIME is used for
    datetime values and DATE for everything else.
    """
    is_temporal = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)
    if not is_temporal:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` of the given `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week, day,
            hour, minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of it make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    field, factor = spans[unit]
    return relativedelta(**{field: factor * n})
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.Dialect) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: provides `WEEK_OFFSET`, used to locate the first day of the week.

    Returns:
        The first date of the year/quarter/month/week containing `d`, or `d`
        itself for "day".

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday through 6 for Sunday.
        # NOTE(review): assumes WEEK_OFFSET is the first day of the week relative
        # to Monday (0 = Monday, -1 = Sunday) - confirm against the dialect.
        # The modulo keeps the distance to the week start within 0..6, so a date
        # that already is the first day of the week maps to itself and the result
        # is never later than `d`.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.Dialect) -> datetime.date:
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary; a date already on a boundary is returned as is."""
    boundary = date_floor(d, unit, dialect)
    return d if boundary == d else boundary + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return the TRUE expression if `condition` is truthy, otherwise the FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: Any, comments: bool = False) -> str:
def gen(expression: t.Any, comments: bool = False) -> str:
    """Cheap pseudo SQL generator that produces sortable, unique strings.

    Optimization needs to sort and dedupe SQL, and the real generator is too
    expensive for that, so a bare minimum generator is used here instead.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

Arguments:
  • expression: the expression to convert into a SQL string.
  • comments: whether to include the expression's comments.
class Gen:
    """Minimal, stack-based pseudo SQL generator.

    `gen` walks an expression iteratively: work items are pushed on `self.stack` and
    popped last-in-first-out, so every handler pushes its pieces in *reverse* output
    order. Popped expressions are dispatched to a `<key>_sql` method when one exists,
    lists are expanded into comma-separated items, and anything else is stringified
    and appended to `self.sqls`.
    """

    def __init__(self):
        # Pending work items: expressions, lists of items, or plain tokens.
        self.stack: t.List[t.Any] = []
        # Output fragments in emission order; joined by `gen`.
        self.sqls: t.List[str] = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """Render `expression` and return the resulting string.

        Args:
            expression: the root item to render.
            comments: whether to emit each node's comments as ` /*...*/`.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    # Pushed before the node's own pieces, hence emitted after them.
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Only drop a separator that was actually pushed; a list holding
                # nothing but None values must not remove an unrelated stack item.
                if pushed:
                    self.stack.pop()
            elif node is not None:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        """Queue `<this> AS <alias>`."""
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Queue `NAME(<expressions>)` for an anonymous function.

        Raises:
            ValueError: if `e.this` is neither a `str` nor an `exp.Identifier`.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            # Quoted identifiers keep their case; unquoted ones are upper-cased.
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        """Queue `<this> BETWEEN <low> AND <high>`."""
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        """Queue `<this>[<expressions>]`."""
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        """Queue the column's parts joined by dots."""
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        if parts:
            # Remove the separator on top of the stack, which would otherwise be
            # emitted before the first part. Guarded so that empty `parts` cannot
            # remove an unrelated stack item.
            self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        """Queue the type name followed by the args after the first `arg_types` entry."""
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        """Queue `<this> IN (<args after the first arg_types entry>)`."""
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): mixed-case token, unlike " ILIKE "; left untouched because
        # it is part of the generated string.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        """Queue `(<this>)`."""
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        """Queue `(<this>)`, then the alias if any, then the args from index 2 on."""
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        """Queue the dotted table parts, then the alias if any, then the args from index 4 on."""
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        if parts:
            # Guarded: with no parts there is no separator to drop, and popping
            # would remove the alias or another pending item instead.
            self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        """Queue ` AS <this>`, followed by `(<columns>)` when columns are present."""
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Queue `<this><op><expression>`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Queue `<op><this>`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Queue `NAME(<arg values>)`, using `e.sql_name()` as the name."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Queue the node's set args as `[":key", value]` pairs.

        Args:
            node: the expression whose args are read.
            arg_index: number of leading `arg_types` entries to skip.

        Returns:
            True if at least one arg was not None (and the pairs were pushed as a
            single list), False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: exp.Expression, comments: bool = False) -> str:
    """Render `expression` and return the resulting string.

    Items are popped from `self.stack` last-in-first-out: expressions are dispatched
    to a `<key>_sql` handler when one exists (falling back to `_function` for
    `exp.Func` nodes, else to the upper-cased key plus `_args`), lists are expanded
    into comma-separated items, and any other non-None item is stringified.

    Args:
        expression: the root item to render.
        comments: whether to emit each node's comments as ` /*...*/`.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            if comments and node.comments:
                # Pushed before the node's own pieces, hence emitted after them.
                self.stack.append(f" /*{','.join(node.comments)}*/")

            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            # Only drop a separator that was actually pushed; a list holding nothing
            # but None values must not remove an unrelated stack item.
            if pushed:
                self.stack.pop()
        elif node is not None:
            self.sqls.append(str(node))

    return "".join(self.sqls)
def add_sql(self, e: exp.Add) -> None:
    """Delegate to `_binary` with the ` + ` token."""
    self._binary(e, " + ")
def alias_sql(self, e: exp.Alias) -> None:
    """Queue the alias so that it is emitted as `<this> AS <alias>`."""
    args = e.args
    # Reverse output order: the stack is popped last-in-first-out.
    self.stack.extend((args.get("alias"), " AS ", args.get("this")))
def and_sql(self, e: exp.And) -> None:
    """Delegate to `_binary` with the ` AND ` token."""
    self._binary(e, " AND ")
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """Queue `NAME(<expressions>)` for an anonymous function.

    A `str` name is upper-cased. An `exp.Identifier` name is wrapped in double quotes
    when quoted and upper-cased otherwise.

    Raises:
        ValueError: if `e.this` is neither a `str` nor an `exp.Identifier`.
    """
    target = e.this

    if isinstance(target, str):
        name = target.upper()
    elif isinstance(target, exp.Identifier):
        raw = target.this
        name = raw.upper() if not target.quoted else f'"{raw}"'
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{target.__class__.__name__}'."
        )

    # Reverse output order: the stack is popped last-in-first-out.
    self.stack.extend((")", e.expressions, "(", name))
def between_sql(self, e: exp.Between) -> None:
    """Queue the node so that it is emitted as `<this> BETWEEN <low> AND <high>`."""
    low = e.args.get("low")
    high = e.args.get("high")
    # Reverse output order: the stack is popped last-in-first-out.
    self.stack.extend((high, " AND ", low, " BETWEEN ", e.this))
def boolean_sql(self, e: exp.Boolean) -> None:
    """Queue `TRUE` when `e.this` is truthy, otherwise `FALSE`."""
    if e.this:
        token = "TRUE"
    else:
        token = "FALSE"
    self.stack.append(token)
def bracket_sql(self, e: exp.Bracket) -> None:
    """Queue the node so that it is emitted as `<this>[<expressions>]`."""
    # Reverse output order: the stack is popped last-in-first-out.
    pieces = ("]", e.expressions, "[", e.this)
    self.stack.extend(pieces)
def column_sql(self, e: exp.Column) -> None:
    """Queue the column's parts so that they are emitted joined by dots."""
    parts = e.parts
    for p in reversed(parts):
        self.stack.extend((p, "."))
    if parts:
        # Remove the separator on top of the stack, which would otherwise be emitted
        # before the first part. Guarded so that empty `parts` cannot remove an
        # unrelated stack item.
        self.stack.pop()
def datatype_sql(self, e: exp.DataType) -> None:
    """Queue the type name, followed by the args after the first `arg_types` entry."""
    # The args are pushed first, so they are emitted after the type name.
    self._args(e, 1)
    type_name = e.this.name
    self.stack.append(f"{type_name} ")
def div_sql(self, e: exp.Div) -> None:
    """Delegate to `_binary` with the ` / ` token."""
    self._binary(e, " / ")
def dot_sql(self, e: exp.Dot) -> None:
    """Delegate to `_binary` with the `.` token (no surrounding spaces)."""
    self._binary(e, ".")
def eq_sql(self, e: exp.EQ) -> None:
    """Delegate to `_binary` with the ` = ` token."""
    self._binary(e, " = ")
def from_sql(self, e: exp.From) -> None:
    """Queue the node so that it is emitted as `FROM <this>`."""
    # Reverse output order: the stack is popped last-in-first-out.
    pieces = (e.this, "FROM ")
    self.stack.extend(pieces)
def gt_sql(self, e: exp.GT) -> None:
    """Delegate to `_binary` with the ` > ` token."""
    self._binary(e, " > ")
def gte_sql(self, e: exp.GTE) -> None:
    """Delegate to `_binary` with the ` >= ` token."""
    self._binary(e, " >= ")
def identifier_sql(self, e: exp.Identifier) -> None:
    """Queue the identifier, wrapped in double quotes when it is quoted."""
    if e.quoted:
        rendered = f'"{e.this}"'
    else:
        rendered = e.this
    self.stack.append(rendered)
def ilike_sql(self, e: exp.ILike) -> None:
    """Delegate to `_binary` with the ` ILIKE ` token."""
    self._binary(e, " ILIKE ")
def in_sql(self, e: exp.In) -> None:
    """Queue `<this> IN (...)`, where `...` is the args after the first `arg_types` entry."""
    # Everything is pushed in reverse output order: closing paren first, then the
    # args, and the opening pieces last so that they are popped first.
    self.stack.append(")")
    self._args(e, 1)
    self.stack.extend(("(", " IN ", e.this))
def intdiv_sql(self, e: exp.IntDiv) -> None:
    """Delegate to `_binary` with the ` DIV ` token."""
    self._binary(e, " DIV ")
def is_sql(self, e: exp.Is) -> None:
    """Delegate to `_binary` with the ` IS ` token."""
    self._binary(e, " IS ")
def like_sql(self, e: exp.Like) -> None:
    """Delegate to `_binary` with the ` Like ` token."""
    # NOTE(review): mixed-case token, unlike the upper-case one used by `ilike_sql`;
    # reproduced as-is because it is part of the generated string.
    self._binary(e, " Like ")
def literal_sql(self, e: exp.Literal) -> None:
    """Queue the literal, wrapped in single quotes when it is a string literal."""
    if e.is_string:
        rendered = f"'{e.this}'"
    else:
        rendered = e.this
    self.stack.append(rendered)
def lt_sql(self, e: exp.LT) -> None:
    """Delegate to `_binary` with the ` < ` token."""
    self._binary(e, " < ")
def lte_sql(self, e: exp.LTE) -> None:
    """Delegate to `_binary` with the ` <= ` token."""
    self._binary(e, " <= ")
def mod_sql(self, e: exp.Mod) -> None:
    """Delegate to `_binary` with the ` % ` token."""
    self._binary(e, " % ")
def mul_sql(self, e: exp.Mul) -> None:
    """Delegate to `_binary` with the ` * ` token."""
    self._binary(e, " * ")
def neg_sql(self, e: exp.Neg) -> None:
    """Delegate to `_unary` with the `-` prefix."""
    self._unary(e, "-")
def neq_sql(self, e: exp.NEQ) -> None:
    """Delegate to `_binary` with the ` <> ` token."""
    self._binary(e, " <> ")
def not_sql(self, e: exp.Not) -> None:
    """Delegate to `_unary` with the `NOT ` prefix."""
    self._unary(e, "NOT ")
def null_sql(self, e: exp.Null) -> None:
    """Queue the `NULL` token; the node itself is not inspected."""
    token = "NULL"
    self.stack.append(token)
def or_sql(self, e: exp.Or) -> None:
    """Delegate to `_binary` with the ` OR ` token."""
    self._binary(e, " OR ")
def paren_sql(self, e: exp.Paren) -> None:
    """Queue the node so that it is emitted as `(<this>)`."""
    # Reverse output order: the stack is popped last-in-first-out.
    pieces = (")", e.this, "(")
    self.stack.extend(pieces)
def sub_sql(self, e: exp.Sub) -> None:
    """Delegate to `_binary` with the ` - ` token."""
    self._binary(e, " - ")
def subquery_sql(self, e: exp.Subquery) -> None:
    """Queue `(<this>)`, then the alias if any, then the args from index 2 onward."""
    # Pushed in reverse output order: trailing args first, parenthesized body last.
    self._args(e, 2)

    table_alias = e.args.get("alias")
    if table_alias:
        self.stack.append(table_alias)

    self.stack.extend((")", e.this, "("))
def table_sql(self, e: exp.Table) -> None:
    """Queue the dotted table parts, then the alias if any, then the args from index 4 onward."""
    # Pushed in reverse output order: trailing args first, table parts last.
    self._args(e, 4)
    alias = e.args.get("alias")
    if alias:
        self.stack.append(alias)
    parts = e.parts
    for p in reversed(parts):
        self.stack.extend((p, "."))
    if parts:
        # Remove the separator on top of the stack, which would otherwise be emitted
        # before the first part. Guarded: with no parts, popping would remove the
        # alias or another pending item instead.
        self.stack.pop()
def tablealias_sql(self, e: exp.TableAlias) -> None:
    """Queue ` AS <this>`, followed by `(<columns>)` when the alias has columns."""
    cols = e.columns

    # The column list is pushed first so that it is emitted after the alias name.
    if cols:
        self.stack.extend((")", cols, "("))

    self.stack.extend((e.this, " AS "))
def var_sql(self, e: exp.Var) -> None:
    """Queue the variable's `this` value unchanged."""
    value = e.this
    self.stack.append(value)