Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import logging
   5import functools
   6import itertools
   7import typing as t
   8from collections import deque, defaultdict
   9from functools import reduce
  10
  11import sqlglot
  12from sqlglot import Dialect, exp
  13from sqlglot.helper import first, merge_ranges, while_changing
  14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  15
  16if t.TYPE_CHECKING:
  17    from sqlglot.dialects.dialect import DialectType
  18
  19    DateTruncBinaryTransform = t.Callable[
  20        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
  21    ]
  22
# Shared "sqlglot" logger; `simplify` uses it to report skipped connector chains
logger = logging.getLogger("sqlglot")

# Final means that an expression should not be simplified.
# This is the `meta` key that `simplify` sets and checks on expressions.
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers.
# `_simplify_integer_cast` only drops a cast when the literal fits these bounds.
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
  33
  34
  35class UnsupportedUnit(Exception):
  36    pass
  37
  38
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect specification, resolved with `Dialect.get_or_raise`; the resolved
            dialect is passed to `simplify_datetrunc`
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Only measure the depth at the top of a connector chain (parent is not a Connector)
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        # Nodes flagged as final are returned untouched (children included)
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for the HAVING clause
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rewritten node may be detached at this point, so it is given the original
        # parent before the remaining passes run
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # `_simplify` is re-applied through `while_changing`
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 136
 137
 138def connector_depth(expression: exp.Expression) -> int:
 139    """
 140    Determine the maximum depth of a tree of Connectors.
 141
 142    For example:
 143        >>> from sqlglot import parse_one
 144        >>> connector_depth(parse_one("a AND b AND c AND d"))
 145        3
 146    """
 147    stack = deque([(expression, 0)])
 148    max_depth = 0
 149
 150    while stack:
 151        expression, depth = stack.pop()
 152
 153        if not isinstance(expression, exp.Connector):
 154            continue
 155
 156        depth += 1
 157        max_depth = max(depth, max_depth)
 158
 159        stack.append((expression.left, depth))
 160        stack.append((expression.right, depth))
 161
 162    return max_depth
 163
 164
 165def catch(*exceptions):
 166    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 167
 168    def decorator(func):
 169        def wrapped(expression, *args, **kwargs):
 170            try:
 171                return func(expression, *args, **kwargs)
 172            except exceptions:
 173                return expression
 174
 175        return wrapped
 176
 177    return decorator
 178
 179
 180def rewrite_between(expression: exp.Expression) -> exp.Expression:
 181    """Rewrite x between y and z to x >= y AND x <= z.
 182
 183    This is done because comparison simplification is only done on lt/lte/gt/gte.
 184    """
 185    if isinstance(expression, exp.Between):
 186        negate = isinstance(expression.parent, exp.Not)
 187
 188        expression = exp.and_(
 189            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 190            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 191            copy=False,
 192        )
 193
 194        if negate:
 195            expression = exp.paren(expression, copy=False)
 196
 197    return expression
 198
 199
# Maps each comparison to its logical complement; `simplify_not` uses it to
# rewrite NOT <comparison> into the complementary comparison on the same operands.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}
 208
 209
 210def simplify_not(expression):
 211    """
 212    Demorgan's Law
 213    NOT (x OR y) -> NOT x AND NOT y
 214    NOT (x AND y) -> NOT x OR NOT y
 215    """
 216    if isinstance(expression, exp.Not):
 217        this = expression.this
 218        if is_null(this):
 219            return exp.null()
 220        if this.__class__ in COMPLEMENT_COMPARISONS:
 221            return COMPLEMENT_COMPARISONS[this.__class__](
 222                this=this.this, expression=this.expression
 223            )
 224        if isinstance(this, exp.Paren):
 225            condition = this.unnest()
 226            if isinstance(condition, exp.And):
 227                return exp.paren(
 228                    exp.or_(
 229                        exp.not_(condition.left, copy=False),
 230                        exp.not_(condition.right, copy=False),
 231                        copy=False,
 232                    )
 233                )
 234            if isinstance(condition, exp.Or):
 235                return exp.paren(
 236                    exp.and_(
 237                        exp.not_(condition.left, copy=False),
 238                        exp.not_(condition.right, copy=False),
 239                        copy=False,
 240                    )
 241                )
 242            if is_null(condition):
 243                return exp.null()
 244        if always_true(this):
 245            return exp.false()
 246        if is_false(this):
 247            return exp.true()
 248        if isinstance(this, exp.Not):
 249            # double negation
 250            # NOT NOT x -> x
 251            return this.this
 252    return expression
 253
 254
 255def flatten(expression):
 256    """
 257    A AND (B AND C) -> A AND B AND C
 258    A OR (B OR C) -> A OR B OR C
 259    """
 260    if isinstance(expression, exp.Connector):
 261        for node in expression.args.values():
 262            child = node.unnest()
 263            if isinstance(child, expression.__class__):
 264                node.replace(child)
 265    return expression
 266
 267
 268def simplify_connectors(expression, root=True):
 269    def _simplify_connectors(expression, left, right):
 270        if left == right:
 271            if isinstance(expression, exp.Xor):
 272                return exp.false()
 273            return left
 274        if isinstance(expression, exp.And):
 275            if is_false(left) or is_false(right):
 276                return exp.false()
 277            if is_null(left) or is_null(right):
 278                return exp.null()
 279            if always_true(left) and always_true(right):
 280                return exp.true()
 281            if always_true(left):
 282                return right
 283            if always_true(right):
 284                return left
 285            return _simplify_comparison(expression, left, right)
 286        elif isinstance(expression, exp.Or):
 287            if always_true(left) or always_true(right):
 288                return exp.true()
 289            if is_false(left) and is_false(right):
 290                return exp.false()
 291            if (
 292                (is_null(left) and is_null(right))
 293                or (is_null(left) and is_false(right))
 294                or (is_false(left) and is_null(right))
 295            ):
 296                return exp.null()
 297            if is_false(left):
 298                return right
 299            if is_false(right):
 300                return left
 301            return _simplify_comparison(expression, left, right, or_=True)
 302
 303    if isinstance(expression, exp.Connector):
 304        return _flat_simplify(expression, _simplify_connectors, root)
 305    return expression
 306
 307
# "Less than" / "greater than" families, used by `_simplify_comparison`
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Comparison node types handled by the comparison / equality / coalesce simplifications
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Maps each inequality to the one with the opposite direction.
# NOTE(review): not referenced in the part of the file shown here - presumably used when
# the two sides of a comparison are swapped; confirm at the use site.
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Operands containing any of these are excluded from comparison merging
NONDETERMINISTIC = (exp.Rand, exp.Randn)
# Connector types accepted by `remove_complements` and `absorb_and_eliminate`
AND_OR = (exp.And, exp.Or)
 328
 329
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons that are operands of the same connector.

    Args:
        expression: the connector the two operands belong to.
        left: the first operand.
        right: the second operand.
        or_: True when the connector is an OR; False (the default) means AND.

    Returns:
        A replacement expression (one of the operands or FALSE), or None when no
        rule applies. One early exit returns `expression` itself instead of None.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands shared by both comparisons; only the ones that are neither constants
        # nor contain a nondeterministic function qualify as the common "column"
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            try:
                # The remaining operand of each comparison is the value being compared to
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # Convert both values to Python objects that can be compared with each other
            if l.is_number and r.is_number:
                l = l.to_py()
                r = r.to_py()
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None
                # python won't compare date and datetime, but many engines will upcast
                l, r = cast_as_datetime(l), cast_as_datetime(r)

            # Try the rules with the operands in both orders
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Two bounds in the same direction: keep only one of them
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Bounds in opposite directions that cannot both hold
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # An equality either contradicts the other comparison or subsumes it
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 390
 391
 392def remove_complements(expression, root=True):
 393    """
 394    Removing complements.
 395
 396    A AND NOT A -> FALSE
 397    A OR NOT A -> TRUE
 398    """
 399    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
 400        ops = set(expression.flatten())
 401        for op in ops:
 402            if isinstance(op, exp.Not) and op.this in ops:
 403                return exp.false() if isinstance(expression, exp.And) else exp.true()
 404
 405    return expression
 406
 407
 408def uniq_sort(expression, root=True):
 409    """
 410    Uniq and sort a connector.
 411
 412    C AND A AND B AND B -> A AND B AND C
 413    """
 414    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 415        flattened = tuple(expression.flatten())
 416
 417        if isinstance(expression, exp.Xor):
 418            result_func = exp.xor
 419            # Do not deduplicate XOR as A XOR A != A if A == True
 420            deduped = None
 421            arr = tuple((gen(e), e) for e in flattened)
 422        else:
 423            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 424            deduped = {gen(e): e for e in flattened}
 425            arr = tuple(deduped.items())
 426
 427        # check if the operands are already sorted, if not sort them
 428        # A AND C AND B -> A AND B AND C
 429        for i, (sql, e) in enumerate(arr[1:]):
 430            if sql < arr[i][0]:
 431                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 432                break
 433        else:
 434            # we didn't have to sort but maybe we need to dedup
 435            if deduped and len(deduped) < len(flattened):
 436                expression = result_func(*deduped.values(), copy=False)
 437
 438    return expression
 439
 440
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are made in place, by replacing operands inside `expression`;
    the (mutated) `expression` itself is returned.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of the nested operands we look into
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # Register the operand under the key it would have with the negation removed,
            # remembering the operand itself and its non-negated side
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A negated side whose complement is a sibling operand is replaced by the
            # constant that is neutral for `kind`
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # An operand that strictly contains another operand's sub-operands is replaced
            # by the constant that is neutral for the outer connector
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
 511
 512
 513def propagate_constants(expression, root=True):
 514    """
 515    Propagate constants for conjunctions in DNF:
 516
 517    SELECT * FROM t WHERE a = b AND b = 5 becomes
 518    SELECT * FROM t WHERE a = 5 AND b = 5
 519
 520    Reference: https://www.sqlite.org/optoverview.html
 521    """
 522
 523    if (
 524        isinstance(expression, exp.And)
 525        and (root or not expression.same_parent)
 526        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 527    ):
 528        constant_mapping = {}
 529        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 530            if isinstance(expr, exp.EQ):
 531                l, r = expr.left, expr.right
 532
 533                # TODO: create a helper that can be used to detect nested literal expressions such
 534                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 535                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 536                    constant_mapping[l] = (id(l), r)
 537
 538        if constant_mapping:
 539            for column in find_all_in_scope(expression, exp.Column):
 540                parent = column.parent
 541                column_id, constant = constant_mapping.get(column) or (None, None)
 542                if (
 543                    column_id is not None
 544                    and id(column) != column_id
 545                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 546                ):
 547                    column.replace(constant.copy())
 548
 549    return expression
 550
 551
# Date/datetime add and subtract functions mapped to the arithmetic operator that undoes them
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# All operations `simplify_equality` can move to the other side of a comparison,
# mapped to the operator used to undo them
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 564
 565
def _is_number(expression: exp.Expression) -> bool:
    """Return the expression's `is_number` flag; used as an operand predicate by `simplify_equality`."""
    return expression.is_number
 568
 569
 570def _is_interval(expression: exp.Expression) -> bool:
 571    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 572
 573
 574@catch(ModuleNotFoundError, UnsupportedUnit)
 575def simplify_equality(expression: exp.Expression) -> exp.Expression:
 576    """
 577    Use the subtraction and addition properties of equality to simplify expressions:
 578
 579        x + 1 = 3 becomes x = 2
 580
 581    There are two binary operations in the above expression: + and =
 582    Here's how we reference all the operands in the code below:
 583
 584          l     r
 585        x + 1 = 3
 586        a   b
 587    """
 588    if isinstance(expression, COMPARISONS):
 589        l, r = expression.left, expression.right
 590
 591        if l.__class__ not in INVERSE_OPS:
 592            return expression
 593
 594        if r.is_number:
 595            a_predicate = _is_number
 596            b_predicate = _is_number
 597        elif _is_date_literal(r):
 598            a_predicate = _is_date_literal
 599            b_predicate = _is_interval
 600        else:
 601            return expression
 602
 603        if l.__class__ in INVERSE_DATE_OPS:
 604            l = t.cast(exp.IntervalOp, l)
 605            a = l.this
 606            b = l.interval()
 607        else:
 608            l = t.cast(exp.Binary, l)
 609            a, b = l.left, l.right
 610
 611        if not a_predicate(a) and b_predicate(b):
 612            pass
 613        elif not a_predicate(b) and b_predicate(a):
 614            a, b = b, a
 615        else:
 616            return expression
 617
 618        return expression.__class__(
 619            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 620        )
 621    return expression
 622
 623
 624def simplify_literals(expression, root=True):
 625    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 626        return _flat_simplify(expression, _simplify_binary, root)
 627
 628    if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
 629        return expression.this.this
 630
 631    if type(expression) in INVERSE_DATE_OPS:
 632        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 633
 634    return expression
 635
 636
# Binary operations that `_simplify_binary` leaves untouched (it returns None for them)
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 638
 639
 640def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
 641    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
 642        this = _simplify_integer_cast(expr.this)
 643    else:
 644        this = expr.this
 645
 646    if isinstance(expr, exp.Cast) and this.is_int:
 647        num = this.to_py()
 648
 649        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
 650        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
 651        # engine-dependent
 652        if (
 653            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
 654        ) or (
 655            UTINYINT_MIN <= num <= UTINYINT_MAX
 656            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
 657        ):
 658            return this
 659
 660    return expr
 661
 662
 663def _simplify_binary(expression, a, b):
 664    if isinstance(expression, COMPARISONS):
 665        a = _simplify_integer_cast(a)
 666        b = _simplify_integer_cast(b)
 667
 668    if isinstance(expression, exp.Is):
 669        if isinstance(b, exp.Not):
 670            c = b.this
 671            not_ = True
 672        else:
 673            c = b
 674            not_ = False
 675
 676        if is_null(c):
 677            if isinstance(a, exp.Literal):
 678                return exp.true() if not_ else exp.false()
 679            if is_null(a):
 680                return exp.false() if not_ else exp.true()
 681    elif isinstance(expression, NULL_OK):
 682        return None
 683    elif is_null(a) or is_null(b):
 684        return exp.null()
 685
 686    if a.is_number and b.is_number:
 687        num_a = a.to_py()
 688        num_b = b.to_py()
 689
 690        if isinstance(expression, exp.Add):
 691            return exp.Literal.number(num_a + num_b)
 692        if isinstance(expression, exp.Mul):
 693            return exp.Literal.number(num_a * num_b)
 694
 695        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 696        if isinstance(expression, exp.Sub):
 697            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 698        if isinstance(expression, exp.Div):
 699            # engines have differing int div behavior so intdiv is not safe
 700            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 701                return None
 702            return exp.Literal.number(num_a / num_b)
 703
 704        boolean = eval_boolean(expression, num_a, num_b)
 705
 706        if boolean:
 707            return boolean
 708    elif a.is_string and b.is_string:
 709        boolean = eval_boolean(expression, a.this, b.this)
 710
 711        if boolean:
 712            return boolean
 713    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 714        date, b = extract_date(a), extract_interval(b)
 715        if date and b:
 716            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 717                return date_literal(date + b, extract_type(a))
 718            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 719                return date_literal(date - b, extract_type(a))
 720    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 721        a, date = extract_interval(a), extract_date(b)
 722        # you cannot subtract a date from an interval
 723        if a and b and isinstance(expression, exp.Add):
 724            return date_literal(a + date, extract_type(b))
 725    elif _is_date_literal(a) and _is_date_literal(b):
 726        if isinstance(expression, exp.Predicate):
 727            a, b = extract_date(a), extract_date(b)
 728            boolean = eval_boolean(expression, a, b)
 729            if boolean:
 730                return boolean
 731
 732    return None
 733
 734
def simplify_parens(expression):
    """
    Remove parentheses that are not needed.

    Returns the wrapped expression when the parentheses can be dropped, otherwise
    the `exp.Paren` itself. Non-paren expressions are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    # Parentheses around a SELECT, or directly under a subquery predicate, are always kept.
    # Otherwise they are dropped as soon as one of the alternatives below holds.
    if (
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            # the parent is neither a condition nor a binary operation
            not isinstance(parent, (exp.Condition, exp.Binary))
            # nested parentheses
            or isinstance(parent, exp.Paren)
            # the content is not a binary operation (except NOT / IS under a predicate)
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            # a predicate whose parent is not a predicate
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # (a + b) + c, (a * b) * c, (a * b) + c and (a * b) - c
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
 761
 762
def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Return True for an instance of `exp.NONNULL_CONSTANTS` or a date literal."""
    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 765
 766
def _is_constant(expression: exp.Expression) -> bool:
    """Return True for an instance of `exp.CONSTANTS` or a date literal."""
    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 769
 770
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons against them.

    - COALESCE with a single argument, or whose first argument is a non-null constant,
      is replaced by that first argument.
    - A comparison between a COALESCE and a constant is expanded into an OR of two
      cases: the arguments before the first constant argument are not null, or they
      are null and the comparison is made against that constant argument.

    Returns the rewritten expression, or `expression` unchanged when nothing applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Find which side of the comparison holds the COALESCE
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant argument and everything after it (this mutates the COALESCE
    # in place, so the copy of `expression` taken below already reflects it)
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
 827
 828
# Expression types that `simplify_concat` operates on
CONCATS = (exp.Concat, exp.DPipe)
 830
 831
 832def simplify_concat(expression):
 833    """Reduces all groups that contain string literals by concatenating them."""
 834    if not isinstance(expression, CONCATS) or (
 835        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 836        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 837    ):
 838        return expression
 839
 840    if isinstance(expression, exp.ConcatWs):
 841        sep_expr, *expressions = expression.expressions
 842        sep = sep_expr.name
 843        concat_type = exp.ConcatWs
 844        args = {}
 845    else:
 846        expressions = expression.expressions
 847        sep = ""
 848        concat_type = exp.Concat
 849        args = {
 850            "safe": expression.args.get("safe"),
 851            "coalesce": expression.args.get("coalesce"),
 852        }
 853
 854    new_args = []
 855    for is_string_group, group in itertools.groupby(
 856        expressions or expression.flatten(), lambda e: e.is_string
 857    ):
 858        if is_string_group:
 859            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 860        else:
 861            new_args.extend(group)
 862
 863    if len(new_args) == 1 and new_args[0].is_string:
 864        return new_args[0]
 865
 866    if concat_type is exp.ConcatWs:
 867        new_args = [sep_expr] + new_args
 868    elif isinstance(expression, exp.DPipe):
 869        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
 870
 871    return concat_type(expressions=new_args, **args)
 872
 873
 874def simplify_conditionals(expression):
 875    """Simplifies expressions like IF, CASE if their condition is statically known."""
 876    if isinstance(expression, exp.Case):
 877        this = expression.this
 878        for case in expression.args["ifs"]:
 879            cond = case.this
 880            if this:
 881                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 882                cond = cond.replace(this.pop().eq(cond))
 883
 884            if always_true(cond):
 885                return case.args["true"]
 886
 887            if always_false(cond):
 888                case.pop()
 889                if not expression.args["ifs"]:
 890                    return expression.args.get("default") or exp.null()
 891    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 892        if always_true(expression.this):
 893            return expression.args["true"]
 894        if always_false(expression.this):
 895            return expression.args.get("false") or exp.null()
 896
 897    return expression
 898
 899
 900def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 901    """
 902    Reduces a prefix check to either TRUE or FALSE if both the string and the
 903    prefix are statically known.
 904
 905    Example:
 906        >>> from sqlglot import parse_one
 907        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 908        'TRUE'
 909    """
 910    if (
 911        isinstance(expression, exp.StartsWith)
 912        and expression.this.is_string
 913        and expression.expression.is_string
 914    ):
 915        return exp.convert(expression.name.startswith(expression.expression.name))
 916
 917    return expression
 918
 919
# A pair of dates; _datetrunc_range documents it as the half-open range [min, max).
DateRange = t.Tuple[datetime.date, datetime.date]
 921
 922
 923def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 924    """
 925    Get the date range for a DATE_TRUNC equality comparison:
 926
 927    Example:
 928        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 929    Returns:
 930        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 931    """
 932    floor = date_floor(date, unit, dialect)
 933
 934    if date != floor:
 935        # This will always be False, except for NULL values.
 936        return None
 937
 938    return floor, floor + interval(unit)
 939
 940
 941def _datetrunc_eq_expression(
 942    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
 943) -> exp.Expression:
 944    """Get the logical expression for a date range"""
 945    return exp.and_(
 946        left >= date_literal(drange[0], target_type),
 947        left < date_literal(drange[1], target_type),
 948        copy=False,
 949    )
 950
 951
 952def _datetrunc_eq(
 953    left: exp.Expression,
 954    date: datetime.date,
 955    unit: str,
 956    dialect: Dialect,
 957    target_type: t.Optional[exp.DataType],
 958) -> t.Optional[exp.Expression]:
 959    drange = _datetrunc_range(date, unit, dialect)
 960    if not drange:
 961        return None
 962
 963    return _datetrunc_eq_expression(left, drange, target_type)
 964
 965
 966def _datetrunc_neq(
 967    left: exp.Expression,
 968    date: datetime.date,
 969    unit: str,
 970    dialect: Dialect,
 971    target_type: t.Optional[exp.DataType],
 972) -> t.Optional[exp.Expression]:
 973    drange = _datetrunc_range(date, unit, dialect)
 974    if not drange:
 975        return None
 976
 977    return exp.and_(
 978        left < date_literal(drange[0], target_type),
 979        left >= date_literal(drange[1], target_type),
 980        copy=False,
 981    )
 982
 983
# Maps a comparison type to a function that rewrites `DATE_TRUNC(unit, x) <op> <date>` into a
# predicate on `x`. simplify_datetrunc calls each function with
# (truncated argument, date, unit, dialect, target type), abbreviated to (l, dt, u, d, t) below.
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # x < date if `date` is already floored to the unit, otherwise x < floor(date) + one unit
    exp.LT: lambda l, dt, u, d, t: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
    # x >= floor(date) + one unit
    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
    # x < floor(date) + one unit
    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
    # x >= ceil(date)
    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
# Every expression type simplify_datetrunc handles: IN plus the binary comparisons above.
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
# Expression types recognized as date truncations.
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 995
 996
 997def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 998    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
 999
1000
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """
    Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Three shapes are handled:
      - a truncation of a date literal, which is folded into the floored date literal
      - a binary comparison between a truncation and a date literal
      - an IN predicate whose left side is a truncation and whose values are all date literals

    Args:
        expression: the expression to simplify.
        dialect: forwarded to the date flooring helpers.

    Returns:
        The rewritten expression, or `expression` itself when no rewrite applies. Because of
        the `catch` decorator, `expression` is also returned when ModuleNotFoundError or
        UnsupportedUnit is raised while simplifying.
    """
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        this = expression.this
        trunc_type = extract_type(this)
        date = extract_date(this)
        if date and expression.unit:
            # The truncated value is itself a date literal, so compute the truncation statically
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        trunc_arg = l.this
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        # A transform may return None (see _datetrunc_eq / _datetrunc_neq), in which case
        # the expression is left untouched
        return (
            DATETRUNC_BINARY_COMPARISONS[comparison](
                trunc_arg, date, unit, dialect, extract_type(r)
            )
            or expression
        )

    if isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                # Values that aren't aligned to the unit yield no range and are skipped
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)
            target_type = extract_type(*rs)

            # One range check per merged range, combined with OR
            return exp.or_(
                *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
            )

    return expression
1064
1065
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order.

    A comparison is kept as is when only its left side is a column or only its right side
    is a constant. It is flipped, using the inverse operator when one is registered, when
    the opposite holds or when the left operand's generated string sorts after the right's.
    """
    comparison = expression.__class__
    if comparison not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_ordered = (left_is_column and not right_is_column) or (
        right_is_const and not left_is_const
    )
    if already_ordered:
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if not needs_swap:
        return expression

    flipped = INVERSE_COMPARISONS.get(comparison, comparison)
    return flipped(this=right, expression=left)
1081
1082
# A CROSS join produces an empty result if the right table is empty, so only certain
# types of joins can be simplified to CROSS joins.
# In other words, LEFT JOIN x ON TRUE != CROSS JOIN x.
# Each entry is a (side, kind) pair that remove_where_true compares to (join.side, join.kind).
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
1092
1093
def remove_where_true(expression):
    """
    Drop conditions that are always true, modifying `expression` in place.

    WHERE clauses with an always-true condition are removed, and joins listed in JOINS
    whose ON condition is always true (and that have no USING clause or join method)
    are turned into CROSS joins.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        join_args = join.args
        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join_args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
1108
1109
def always_true(expression):
    """Return a truthy value for a TRUE boolean or for any literal, and a falsy one otherwise."""
    true_boolean = isinstance(expression, exp.Boolean) and expression.this
    if true_boolean:
        return true_boolean
    return isinstance(expression, exp.Literal)
1114
1115
def always_false(expression):
    """Return whether `expression` is the FALSE boolean or NULL."""
    if is_false(expression):
        return True
    return is_null(expression)
1118
1119
def is_complement(a, b):
    """Return whether `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1122
1123
def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a Boolean node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1126
1127
def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a Null node (subclasses don't count)."""
    return type(a) is exp.Null
1130
1131
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison represented by `expression` on the Python values `a` and `b`.

    Returns:
        A boolean literal holding the outcome, or None if `expression` is not one of the
        supported comparison types.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))

    return None
1146
1147
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` into a date.

    Datetimes are reduced to their date, dates are returned unchanged and anything else is
    parsed as an ISO formatted string. Returns None if that string can't be parsed.
    """
    if isinstance(value, datetime.date):
        # datetime.datetime is a subclass of datetime.date, so it has to be narrowed here
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None

    return parsed.date()
1157
1158
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` into a datetime.

    Datetimes are returned unchanged, dates become midnight of that day and anything else
    is parsed as an ISO formatted string. Returns None if that string can't be parsed.
    """
    if isinstance(value, datetime.date):
        if isinstance(value, datetime.datetime):
            return value
        # A plain date carries no time of day, so it maps to 00:00:00
        return datetime.datetime(value.year, value.month, value.day)

    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
1168
1169
def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Convert `value` into the Python date or datetime matching the data type `to`.

    Returns None for falsy values and for types that are neither DATE nor temporal.
    """
    if not value:
        return None

    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)

    return cast_as_datetime(value) if to.is_type(*exp.DataType.TEMPORAL_TYPES) else None
1178
1179
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Statically evaluate a cast of a literal into a Python date or datetime.

    Both CAST and format-less TS_OR_DS_TO_DATE nodes are supported, including nested ones.
    Returns None when the value can't be determined.
    """
    if isinstance(cast, exp.Cast):
        target = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        target = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    inner = cast.this

    if isinstance(inner, exp.Literal):
        return cast_value(inner.name, target)

    if isinstance(inner, (exp.Cast, exp.TsOrDsToDate)):
        # Evaluate the nested cast first, then convert its result to the outer type
        return cast_value(extract_date(inner), target)

    return None
1195
1196
def _is_date_literal(expression: exp.Expression) -> bool:
    """Return whether extract_date can determine a date for `expression`."""
    extracted = extract_date(expression)
    return extracted is not None
1199
1200
def extract_interval(expression):
    """
    Convert an interval expression into the delta returned by `interval`.

    Returns None if the unit is unsupported, the module `interval` needs is missing or
    the interval's value can't be converted to an integer.
    """
    try:
        count = int(expression.this.to_py())
        return interval(expression.text("unit").lower(), count)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
1208
1209
def extract_type(*expressions):
    """
    Return the first truthy type found among `expressions`.

    The target type is used for casts and the `type` attribute for everything else. If no
    expression has a truthy type, the last value inspected is returned (None when called
    without arguments).
    """
    found = None

    for expression in expressions:
        found = expression.to if isinstance(expression, exp.Cast) else expression.type
        if found:
            return found

    return found
1218
1219
def date_literal(date, target_type=None):
    """
    Build a cast of the string form of `date` to `target_type`.

    When `target_type` is missing or isn't a temporal type, DATETIME is used for datetimes
    and DATE for everything else.
    """
    is_temporal_target = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)

    if not is_temporal_target:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
1229
1230
def interval(unit: str, n: int = 1):
    """
    Build a relativedelta spanning `n` times the given unit.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed below.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of it make up a single unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    span = spans.get(unit)
    if span is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = span
    return relativedelta(**{keyword: factor * n})
1252
1253
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date or datetime to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: only read for "week", where its WEEK_OFFSET shifts the start of the week.

    Returns:
        The floored value. For datetimes the time of day is reset to midnight as well,
        because every supported unit spans at least one whole day.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    if isinstance(d, datetime.datetime):
        # Without this, a datetime such as 2021-01-01 10:30 would be reported as already
        # floored to "year" or "day" even though its truncation is 2021-01-01 00:00
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1275
1276
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary, leaving it as is if it's already on one."""
    floor = date_floor(d, unit, dialect)
    return d if floor == d else floor + interval(unit)
1284
1285
def boolean_literal(condition):
    """Return the TRUE expression if `condition` is truthy and the FALSE expression otherwise."""
    if condition:
        return exp.true()
    return exp.false()
1288
1289
def _flat_simplify(expression, simplifier, root=True):
    """
    Simplify the flattened operands of `expression` pairwise.

    Args:
        expression: the node whose flattened operands are combined.
        simplifier: called as `simplifier(expression, a, b)` for pairs of operands; a truthy
            result other than `expression` itself replaces both operands.
        root: whether `expression` is the root of the current simplification.

    Returns:
        The remaining operands folded back into nodes of the same class as `expression`,
        or `expression` itself if no pair could be combined.
    """
    # Only run when `expression` is the root or `same_parent` is falsy for it
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # `a` and `b` collapse into `result`, which goes back to the front of
                    # the queue so that it's compared against the remaining operands next.
                    # Breaking right after the removal avoids iterating a mutated deque.
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # `a` couldn't be combined with any remaining operand, so it's final
                operands.append(a)

        # Rebuild only if at least one pair was actually combined
        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
1314
1315
def gen(expression: t.Any) -> str:
    """Generate a sortable, unique pseudo sql string for `expression`.

    Optimization needs to sort and dedupe sql, and the actual generator is expensive to
    call, so a bare minimum sql generator is used for this instead.
    """
    generator = Gen()
    return generator.gen(expression)
1323
1324
class Gen:
    """
    Bare minimum pseudo sql generator driven by an explicit stack instead of recursion.

    Handlers push the pieces of a node onto `self.stack`. Since the stack is popped from
    its end, pieces are pushed in reverse: whatever is pushed last is emitted first.
    """

    def __init__(self):
        # Pending work: expressions, lists and already rendered strings
        self.stack = []
        # Output fragments, in emission order
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` and return the concatenation of all emitted fragments."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                # Prefer a dedicated `<key>_sql` handler, then the generic function
                # rendering, and finally the upper cased key followed by the node's args
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push the items in reverse, each followed by a comma, so that they're
                # emitted in their original order with commas in between
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                if node:
                    # Drop the comma that would otherwise precede the first item
                    self.stack.pop()
            else:
                # Anything else is emitted through str(), except for None
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        # Emitted as: <this> AS <alias>
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Emit NAME(args); quoted identifiers keep their case, other names are upper cased."""
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        # Emitted as: <this> BETWEEN <low> AND <high>
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        # Emitted as: <this>[<expressions>]
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Emit the parts separated by dots; the final pop drops the leading dot
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # The type name comes first, followed by the args after the first arg type
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        # Emitted as: <this> IN (<args after the first arg type>)
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        # String literals are wrapped in single quotes, other literals are emitted as is
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        # Emitted as: (<this>)
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Emitted as: (<this>), then the alias if any, then the args after the first two
        # arg types
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Emitted as: the dot separated parts, then the alias if any, then the args after
        # the first four arg types
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        # Emitted as: " AS " <this>, followed by (<columns>) when there are any
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Pushed in reverse so that the output reads: <this> <op> <expression>
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        # Pushed in reverse so that the output reads: <op> <this>
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        # Emitted as: NAME(<all arg values, comma separated>)
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push the args of `node` that are set, as a list of [":<name>", value] pairs.

        Args:
            node: the expression whose args are pushed.
            arg_index: number of leading arg types to skip.

        Returns:
            Whether anything was pushed onto the stack.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types or arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
logger = <Logger sqlglot (WARNING)>
FINAL = 'final'
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date unit is not supported by the date helpers."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, max_depth: Optional[int] = None):
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved through `Dialect.get_or_raise` and passed on to `simplify_datetrunc`
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # The depth check only runs for the topmost connector of a chain
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Projections that contain a group by key are marked as final too
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # The same applies to the HAVING clause
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Simplify the children; nested calls are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The node may have been replaced above, so point it at the original parent
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression: expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
  • max_depth: Chains of Connectors (AND, OR, etc) exceeding max_depth will be skipped
Returns:

sqlglot.Expression: simplified expression

def connector_depth(expression: sqlglot.expressions.Expression) -> int:
def connector_depth(expression: exp.Expression) -> int:
    """
    Compute how deeply Connectors are nested within a tree of Connectors.

    For example:
        >>> from sqlglot import parse_one
        >>> connector_depth(parse_one("a AND b AND c AND d"))
        3
    """
    deepest = 0
    pending = deque([(expression, 0)])

    while pending:
        node, parent_depth = pending.pop()

        if isinstance(node, exp.Connector):
            depth = parent_depth + 1
            deepest = max(depth, deepest)
            pending.extend(((node.left, depth), (node.right, depth)))

    return deepest

Determine the maximum depth of a tree of Connectors.

For example:
>>> from sqlglot import parse_one
>>> connector_depth(parse_one("a AND b AND c AND d"))
3
def catch(*exceptions):
def catch(*exceptions):
    """
    Decorator that ignores a simplification function if any of `exceptions` are raised

    Args:
        exceptions: the exception types to catch.

    Returns:
        A decorator whose wrapped function returns its first argument, the expression,
        unchanged when one of `exceptions` is raised. Other exceptions propagate.
    """

    def decorator(func):
        # Keep the wrapped function's name and docstring for introspection and debugging
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that skips a simplification function, returning the expression unchanged, if any of `exceptions` is raised.

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    Comparison simplification only handles lt/lte/gt/gte, which is why BETWEEN is
    expanded first.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Must be checked before the BETWEEN node is replaced by the conjunction
    negate = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_bound, upper_bound, copy=False)

    # Under a NOT, the conjunction is wrapped in parentheses
    return exp.paren(conjunction, copy=False) if negate else conjunction

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify negations, among others through De Morgan's laws:
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    operand_type = operand.__class__
    if operand_type in COMPLEMENT_COMPARISONS:
        # Negating a comparison yields its complement on the same operands
        complement = COMPLEMENT_COMPARISONS[operand_type]
        return complement(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()

        if isinstance(condition, (exp.And, exp.Or)):
            # The connector flips and the negation moves onto both sides
            dual = exp.or_ if isinstance(condition, exp.And) else exp.and_
            return exp.paren(
                dual(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            )

        if is_null(condition):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    Remove redundant parentheses around nested connectors of the same type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        # Only splice the child in when it is the same kind of connector
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Fold AND / OR / XOR connectors whose operands are statically known
    (TRUE, FALSE, NULL) or identical, delegating everything else to the
    comparison simplifier.
    """

    def _simplify_connectors(expression, left, right):
        # Identical operands: A AND A -> A, A OR A -> A, A XOR A -> FALSE
        if left == right:
            return exp.false() if isinstance(expression, exp.Xor) else left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()

            left_true = always_true(left)
            right_true = always_true(right)
            if left_true:
                return exp.true() if right_true else right
            if right_true:
                return left

            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_false, right_false = is_false(left), is_false(right)
            left_null, right_null = is_null(left), is_null(right)

            if left_false and right_false:
                return exp.false()
            if (left_null and (right_null or right_false)) or (left_false and right_null):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left

            return _simplify_comparison(expression, left, right, or_=True)

        # Other connectors (e.g. XOR with distinct operands) are left alone
        return None

    if not isinstance(expression, exp.Connector):
        return expression

    return _flat_simplify(expression, _simplify_connectors, root)
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
AND_OR = (<class 'sqlglot.expressions.And'>, <class 'sqlglot.expressions.Or'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR) or not (root or not expression.same_parent):
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )

    if has_complement:
        if isinstance(expression, exp.And):
            return exp.false()
        return exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE.

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by their `gen` string. XOR operands are sorted but
    never deduplicated.
    """
    if not (isinstance(expression, exp.Connector) and (root or not expression.same_parent)):
        return expression

    operands = tuple(expression.flatten())

    if isinstance(expression, exp.Xor):
        build = exp.xor
        # Do not deduplicate XOR as A XOR A != A if A == True
        unique = None
        keyed = tuple((gen(operand), operand) for operand in operands)
    else:
        build = exp.and_ if isinstance(expression, exp.And) else exp.or_
        unique = {gen(operand): operand for operand in operands}
        keyed = tuple(unique.items())

    # A AND C AND B -> A AND B AND C: rebuild only if some neighbour pair is out of order
    out_of_order = any(current[0] < previous[0] for previous, current in zip(keyed, keyed[1:]))

    if out_of_order:
        return build(*(operand for _, operand in sorted(keyed)), copy=False)

    # Already sorted, but duplicates may still need to be dropped
    if unique and len(unique) < len(operands):
        return build(*unique.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrite happens in place by replacing sub-nodes with boolean
    constants (or with the surviving operand); the same `expression`
    object is returned.
    """
    if not isinstance(expression, AND_OR) or not (root or not expression.same_parent):
        return expression

    # Nested groups of interest use the opposite connector of the outer one
    inner_kind = exp.Or if isinstance(expression, exp.And) else exp.And

    operands = tuple(expression.flatten())

    # Lookup tables:
    # - every top-level operand, used to find complements for absorption
    seen = set()
    # - for each sub-operand, the operand sets it appears in (subset checks for absorption)
    containing_sets = defaultdict(list)
    # - pairs of complements, used for elimination
    complement_pairs = defaultdict(list)

    for operand in operands:
        seen.add(operand)

        if isinstance(operand, inner_kind):
            # In cases like: (A AND B) OR (A AND B AND C)
            # the sub-operands are the members of each group
            members = set(operand.flatten())
            for member in members:
                containing_sets[member].append(members)

            left, right = operand.unnest_operands()
            if isinstance(left, exp.Not):
                complement_pairs[frozenset((left.this, right))].append((operand, right))
            if isinstance(right, exp.Not):
                complement_pairs[frozenset((left, right.this))].append((operand, left))
        else:
            # In cases like: A OR (A AND B), the plain operand A is its own group
            containing_sets[operand].append({operand})

    for operand in operands:
        if not isinstance(operand, inner_kind):
            continue

        left, right = operand.unnest_operands()

        # Absorb a negated side whose positive form is a top-level operand
        for side in (left, right):
            if isinstance(side, exp.Not) and side.this in seen:
                side.replace(exp.true() if inner_kind == exp.And else exp.false())
                break
        else:
            # Absorb the whole group when another group is a strict subset of it
            members = set(operand.flatten())
            if any(
                any(subset < members for subset in containing_sets[member]) for member in members
            ):
                operand.replace(exp.false() if inner_kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in complement_pairs[frozenset((left, right))]:
                operand.replace(complement)
                other.replace(complement)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # column -> (id of the column node inside the defining equality, literal value)
    constants = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue

        lhs, rhs = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            constants[lhs] = (id(lhs), rhs)

    if not constants:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        defining_id, literal = constants.get(column) or (None, None)

        # Unknown column, or the very node that defines the constant
        if defining_id is None or id(column) == defining_id:
            continue
        # Leave `col IS NULL` checks untouched
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(literal.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
170        def wrapped(expression, *args, **kwargs):
171            try:
172                return func(expression, *args, **kwargs)
173            except exceptions:
174                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Fold literal operands of non-connector binary expressions, collapse a
    double negation (`- - x` -> `x`), and simplify date arithmetic nodes
    registered in INVERSE_DATE_OPS.
    """
    if isinstance(expression, exp.Binary):
        if not isinstance(expression, exp.Connector):
            return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        negated = expression.this
        if isinstance(negated, exp.Neg):
            return negated.this

    if type(expression) in INVERSE_DATE_OPS:
        simplified = _simplify_binary(expression, expression.this, expression.interval())
        return simplified or expression

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """
    Drop a pair of parentheses when the checks below consider them
    redundant, returning the wrapped expression; otherwise return the
    `exp.Paren` node unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    # Parenthesized SELECTs and operands of subquery predicates always keep their parens
    if isinstance(this, exp.Select) or isinstance(parent, exp.SubqueryPredicate):
        return expression

    parent_is_predicate = isinstance(parent, exp.Predicate)

    removable = (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    )

    return this if removable else expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    A COALESCE with a single argument, or whose first argument is a
    non-null constant, becomes that argument. A comparison between a
    COALESCE and a constant is split into an IS NULL / IS NOT NULL
    disjunction built around the first constant argument.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    first_constant = next(
        (
            (index, candidate)
            for index, candidate in enumerate(coalesce.expressions)
            if _is_constant(candidate)
        ),
        None,
    )
    if first_constant is None:
        return expression
    arg_index, arg = first_constant

    # Everything from the first constant onwards can never be reached
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    when_present = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    when_missing = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(when_present, when_missing, copy=False))
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Merge every run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        extra_args = {}
    else:
        operands = expression.expressions
        sep = ""
        concat_type = exp.Concat
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for is_string_run, run in runs:
        if is_string_run:
            joined = sep.join(literal.name for literal in run)
            merged.append(exp.Literal.string(joined))
        else:
            merged.extend(run)

    # Everything collapsed into a single string literal
    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if concat_type is exp.ConcatWs:
        merged = [sep_expr] + merged
    elif isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), merged)

    return concat_type(expressions=merged, **extra_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE:
      - `CASE x WHEN v ...` is first rewritten to `CASE WHEN x = v ...`.
      - Branches whose condition is always false are removed; if none remain,
        the ELSE value (or NULL) is returned.
      - A branch whose condition is always true replaces the whole CASE only
        when no undecided branch precedes it, because SQL evaluates WHEN
        clauses in order and an earlier branch could still match at runtime.

    For a standalone IF (not a CASE branch), the statically selected arm is
    returned, with NULL standing in for a missing false arm.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # True once a branch that cannot be decided statically has been kept
        undecided_before = False

        # Iterate over a snapshot: `case.pop()` below removes entries from the
        # live "ifs" list, which would otherwise make the loop skip branches.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if undecided_before:
                    # An earlier branch may still win at runtime, so the CASE must stay.
                    break
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                undecided_before = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold STARTSWITH into a boolean constant when both the string and the
    prefix are string literals.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    both_static = expression.this.is_string and expression.expression.is_string
    if not both_static:
        return expression

    text = expression.name
    prefix = expression.expression.name
    return exp.convert(text.startswith(prefix))

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect, sqlglot.expressions.DataType], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.In'>}
def simplify_datetrunc(expression, *args, **kwargs):
170        def wrapped(expression, *args, **kwargs):
171            try:
172                return func(expression, *args, **kwargs)
173            except exceptions:
174                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order: columns on
    the left, constants on the right, otherwise ordered by their `gen`
    string. When the operands are swapped the operator is replaced by its
    inverse (or kept, if it has none registered).
    """
    comparison_type = expression.__class__
    if comparison_type not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    # Already canonical: column on the left or constant on the right
    if (left_is_column and not right_is_column) or (right_is_const and not left_is_const):
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if needs_swap:
        swapped_type = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
        return swapped_type(this=right, expression=left)

    return expression
JOINS = {('RIGHT', ''), ('', 'INNER'), ('', ''), ('RIGHT', 'OUTER')}
def remove_where_true(expression):
def remove_where_true(expression):
    """
    Remove WHERE clauses whose condition is always true, and turn joins
    with an always-true ON condition (and no USING / method, with a
    side/kind listed in JOINS) into CROSS joins. Mutates `expression`.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """
    Return a truthy value if `expression` is a TRUE boolean or any literal.

    Note that every `exp.Literal` counts as true here, whatever its value.
    """
    true_boolean = isinstance(expression, exp.Boolean) and expression.this
    return true_boolean or isinstance(expression, exp.Literal)
def always_false(expression):
def always_false(expression):
    """Return True if `expression` is the FALSE boolean or NULL."""
    if is_false(expression):
        return True
    return is_null(expression)
def is_complement(a, b):
def is_complement(a, b):
    """Return whether `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True only for a node that is exactly `exp.Boolean` with a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True only for a node whose type is exactly `exp.Null`."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison represented by `expression` on the Python
    values `a` and `b` and return the matching boolean literal, or None
    when the node type is not a supported comparison.
    """
    # Checked in order; each comparison is only evaluated for the matching node type
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`.

    Args:
        value: a `datetime.datetime` (its date part is used), a
            `datetime.date` (returned as is), or an ISO-8601 string.

    Returns:
        The resulting date, or None if `value` is not a parseable ISO string.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: `value` is not a string at all; ValueError: malformed string
        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`.

    Args:
        value: a `datetime.datetime` (returned as is), a `datetime.date`
            (promoted to midnight of that day), or an ISO-8601 string.

    Returns:
        The resulting datetime, or None if `value` is not a parseable ISO string.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: `value` is not a string at all; ValueError: malformed string
        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Cast `value` to the Python type matching the temporal data type `to`.

    Args:
        value: the value to convert; any falsy value yields None.
        to: the target data type.

    Returns:
        A `datetime.date` when `to` is DATE, a `datetime.datetime` for the
        other temporal types, and None for non-temporal types or values
        that cannot be converted.
    """
    if not value:
        return None
    # DATE is checked before the generic temporal check so it maps to a plain date
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a (possibly nested) CAST / TS_OR_DS_TO_DATE of a
    literal into a Python date or datetime.

    Args:
        cast: the expression to evaluate.

    Returns:
        The date or datetime value, or None if `cast` is not a supported
        node, its innermost operand is not a literal, or the value cannot
        be converted to the target type.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without an explicit format, the result is treated as a plain DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are evaluated inside-out
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """
    Convert an interval expression into the delta object built by
    `interval`, or return None when the amount is not an integer, the
    unit is unsupported, or the date library is unavailable.
    """
    try:
        amount = int(expression.this.to_py())
        unit_name = expression.text("unit").lower()
        return interval(unit_name, amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def extract_type(*expressions):
def extract_type(*expressions):
    """
    Return the first truthy type found among `expressions`, using the
    cast target for `exp.Cast` nodes and the `type` attribute otherwise.
    If none is truthy, the value computed for the last expression (or
    None when called without arguments) is returned.
    """
    found = None

    for candidate in expressions:
        found = candidate.to if isinstance(candidate, exp.Cast) else candidate.type
        if found:
            return found

    return found
def date_literal(date, target_type=None):
def date_literal(date, target_type=None):
    """
    Build `CAST('<date>' AS <type>)` for a Python date or datetime.

    `target_type` is used when it is a temporal type; otherwise the type
    falls back to DATETIME for `datetime.datetime` values and DATE for
    everything else.
    """
    is_temporal = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)

    if not is_temporal:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """
    Build a `dateutil` relativedelta spanning `n` units.

    Supported units: year, quarter, month, week, day, hour, minute, second.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    unit_spec = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    spec = unit_spec.get(unit)
    if spec is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = spec
    return relativedelta(**{keyword: multiplier * n})
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Truncate `d` to the start of the period given by `unit`.

    Args:
        d: the date to truncate.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: provides `WEEK_OFFSET`, used here as the first day of the
            week expressed as an offset from Monday (0 = Monday, -1 = Sunday).

    Returns:
        The first day of the period containing `d`.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday .. 6 for Sunday. The modulo keeps the
        # distance back to the week start within 0..6, so a date that is
        # itself a week start is returned unchanged and the result is never
        # later than `d`, whatever the sign of the offset.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` up to the start of the next `unit` period, leaving it
    unchanged when it already sits on a period boundary.
    """
    floored = date_floor(d, unit, dialect)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return the TRUE expression if `condition` is truthy, else the FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: Any) -> str:
def gen(expression: t.Any) -> str:
    """Produce a cheap pseudo-SQL string for `expression`, suitable for sorting and deduping.

    The optimizer needs sortable, unique keys for expressions, and the real
    SQL generator is too expensive for that, so a bare-minimum generator
    is used instead.
    """
    generator = Gen()
    return generator.gen(expression)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

class Gen:
class Gen:
    """
    Bare-minimum, non-recursive pseudo-SQL generator.

    Work items are kept on an explicit stack; the item on top is emitted
    next, so handlers push their pieces in reverse output order. A node is
    rendered by the `<key>_sql` method matching its `key` when one exists,
    by `_function` for other functions, and otherwise as its upper-cased
    key followed by its arguments.
    """

    def __init__(self):
        # Pending work items: expressions, lists of items, or plain values to emit
        self.stack = []
        # Output fragments, joined at the end of `gen`
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` and return the resulting string."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push the items comma-separated, skipping None entries
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator in front of the first item. Only do so when
                # something was actually pushed, otherwise an unrelated entry
                # (e.g. a closing paren) would be removed from the stack.
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            # Quoted names keep their case, unquoted ones are upper-cased
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator in front of the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE: the mixed-case keyword is kept as is; changing it would alter
        # the generated strings and therefore the resulting sort order.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator in front of the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Pushed in reverse so that the left operand is emitted first
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push the set arguments of `node` as `:name value` pairs, skipping
        the first `arg_index` declared argument names.

        Returns:
            True if at least one argument was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: sqlglot.expressions.Expression) -> str:
1331    def gen(self, expression: exp.Expression) -> str:
1332        self.stack = [expression]
1333        self.sqls.clear()
1334
1335        while self.stack:
1336            node = self.stack.pop()
1337
1338            if isinstance(node, exp.Expression):
1339                exp_handler_name = f"{node.key}_sql"
1340
1341                if hasattr(self, exp_handler_name):
1342                    getattr(self, exp_handler_name)(node)
1343                elif isinstance(node, exp.Func):
1344                    self._function(node)
1345                else:
1346                    key = node.key.upper()
1347                    self.stack.append(f"{key} " if self._args(node) else key)
1348            elif type(node) is list:
1349                for n in reversed(node):
1350                    if n is not None:
1351                        self.stack.extend((n, ","))
1352                if node:
1353                    self.stack.pop()
1354            else:
1355                if node is not None:
1356                    self.sqls.append(str(node))
1357
1358        return "".join(self.sqls)
def add_sql(self, e: sqlglot.expressions.Add) -> None:
1360    def add_sql(self, e: exp.Add) -> None:
1361        self._binary(e, " + ")
def alias_sql(self, e: sqlglot.expressions.Alias) -> None:
1363    def alias_sql(self, e: exp.Alias) -> None:
1364        self.stack.extend(
1365            (
1366                e.args.get("alias"),
1367                " AS ",
1368                e.args.get("this"),
1369            )
1370        )
def and_sql(self, e: sqlglot.expressions.And) -> None:
1372    def and_sql(self, e: exp.And) -> None:
1373        self._binary(e, " AND ")
def anonymous_sql(self, e: sqlglot.expressions.Anonymous) -> None:
1375    def anonymous_sql(self, e: exp.Anonymous) -> None:
1376        this = e.this
1377        if isinstance(this, str):
1378            name = this.upper()
1379        elif isinstance(this, exp.Identifier):
1380            name = this.this
1381            name = f'"{name}"' if this.quoted else name.upper()
1382        else:
1383            raise ValueError(
1384                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
1385            )
1386
1387        self.stack.extend(
1388            (
1389                ")",
1390                e.expressions,
1391                "(",
1392                name,
1393            )
1394        )
def between_sql(self, e: sqlglot.expressions.Between) -> None:
1396    def between_sql(self, e: exp.Between) -> None:
1397        self.stack.extend(
1398            (
1399                e.args.get("high"),
1400                " AND ",
1401                e.args.get("low"),
1402                " BETWEEN ",
1403                e.this,
1404            )
1405        )
def boolean_sql(self, e: sqlglot.expressions.Boolean) -> None:
1407    def boolean_sql(self, e: exp.Boolean) -> None:
1408        self.stack.append("TRUE" if e.this else "FALSE")
def bracket_sql(self, e: sqlglot.expressions.Bracket) -> None:
1410    def bracket_sql(self, e: exp.Bracket) -> None:
1411        self.stack.extend(
1412            (
1413                "]",
1414                e.expressions,
1415                "[",
1416                e.this,
1417            )
1418        )
def column_sql(self, e: sqlglot.expressions.Column) -> None:
1420    def column_sql(self, e: exp.Column) -> None:
1421        for p in reversed(e.parts):
1422            self.stack.extend((p, "."))
1423        self.stack.pop()
def datatype_sql(self, e: sqlglot.expressions.DataType) -> None:
1425    def datatype_sql(self, e: exp.DataType) -> None:
1426        self._args(e, 1)
1427        self.stack.append(f"{e.this.name} ")
def div_sql(self, e: sqlglot.expressions.Div) -> None:
1429    def div_sql(self, e: exp.Div) -> None:
1430        self._binary(e, " / ")
def dot_sql(self, e: sqlglot.expressions.Dot) -> None:
1432    def dot_sql(self, e: exp.Dot) -> None:
1433        self._binary(e, ".")
def eq_sql(self, e: sqlglot.expressions.EQ) -> None:
1435    def eq_sql(self, e: exp.EQ) -> None:
1436        self._binary(e, " = ")
def from_sql(self, e: sqlglot.expressions.From) -> None:
1438    def from_sql(self, e: exp.From) -> None:
1439        self.stack.extend((e.this, "FROM "))
def gt_sql(self, e: sqlglot.expressions.GT) -> None:
1441    def gt_sql(self, e: exp.GT) -> None:
1442        self._binary(e, " > ")
def gte_sql(self, e: sqlglot.expressions.GTE) -> None:
1444    def gte_sql(self, e: exp.GTE) -> None:
1445        self._binary(e, " >= ")
def identifier_sql(self, e: sqlglot.expressions.Identifier) -> None:
1447    def identifier_sql(self, e: exp.Identifier) -> None:
1448        self.stack.append(f'"{e.this}"' if e.quoted else e.this)
def ilike_sql(self, e: sqlglot.expressions.ILike) -> None:
1450    def ilike_sql(self, e: exp.ILike) -> None:
1451        self._binary(e, " ILIKE ")
def in_sql(self, e: sqlglot.expressions.In) -> None:
1453    def in_sql(self, e: exp.In) -> None:
1454        self.stack.append(")")
1455        self._args(e, 1)
1456        self.stack.extend(
1457            (
1458                "(",
1459                " IN ",
1460                e.this,
1461            )
1462        )
def intdiv_sql(self, e: sqlglot.expressions.IntDiv) -> None:
1464    def intdiv_sql(self, e: exp.IntDiv) -> None:
1465        self._binary(e, " DIV ")
def is_sql(self, e: sqlglot.expressions.Is) -> None:
1467    def is_sql(self, e: exp.Is) -> None:
1468        self._binary(e, " IS ")
def like_sql(self, e: sqlglot.expressions.Like) -> None:
1470    def like_sql(self, e: exp.Like) -> None:
1471        self._binary(e, " Like ")
def literal_sql(self, e: sqlglot.expressions.Literal) -> None:
1473    def literal_sql(self, e: exp.Literal) -> None:
1474        self.stack.append(f"'{e.this}'" if e.is_string else e.this)
def lt_sql(self, e: sqlglot.expressions.LT) -> None:
1476    def lt_sql(self, e: exp.LT) -> None:
1477        self._binary(e, " < ")
def lte_sql(self, e: sqlglot.expressions.LTE) -> None:
1479    def lte_sql(self, e: exp.LTE) -> None:
1480        self._binary(e, " <= ")
def mod_sql(self, e: sqlglot.expressions.Mod) -> None:
1482    def mod_sql(self, e: exp.Mod) -> None:
1483        self._binary(e, " % ")
def mul_sql(self, e: sqlglot.expressions.Mul) -> None:
1485    def mul_sql(self, e: exp.Mul) -> None:
1486        self._binary(e, " * ")
def neg_sql(self, e: sqlglot.expressions.Neg) -> None:
1488    def neg_sql(self, e: exp.Neg) -> None:
1489        self._unary(e, "-")
def neq_sql(self, e: sqlglot.expressions.NEQ) -> None:
1491    def neq_sql(self, e: exp.NEQ) -> None:
1492        self._binary(e, " <> ")
def not_sql(self, e: sqlglot.expressions.Not) -> None:
1494    def not_sql(self, e: exp.Not) -> None:
1495        self._unary(e, "NOT ")
def null_sql(self, e: sqlglot.expressions.Null) -> None:
1497    def null_sql(self, e: exp.Null) -> None:
1498        self.stack.append("NULL")
def or_sql(self, e: sqlglot.expressions.Or) -> None:
1500    def or_sql(self, e: exp.Or) -> None:
1501        self._binary(e, " OR ")
def paren_sql(self, e: sqlglot.expressions.Paren) -> None:
1503    def paren_sql(self, e: exp.Paren) -> None:
1504        self.stack.extend(
1505            (
1506                ")",
1507                e.this,
1508                "(",
1509            )
1510        )
def sub_sql(self, e: sqlglot.expressions.Sub) -> None:
1512    def sub_sql(self, e: exp.Sub) -> None:
1513        self._binary(e, " - ")
def subquery_sql(self, e: sqlglot.expressions.Subquery) -> None:
1515    def subquery_sql(self, e: exp.Subquery) -> None:
1516        self._args(e, 2)
1517        alias = e.args.get("alias")
1518        if alias:
1519            self.stack.append(alias)
1520        self.stack.extend((")", e.this, "("))
def table_sql(self, e: sqlglot.expressions.Table) -> None:
1522    def table_sql(self, e: exp.Table) -> None:
1523        self._args(e, 4)
1524        alias = e.args.get("alias")
1525        if alias:
1526            self.stack.append(alias)
1527        for p in reversed(e.parts):
1528            self.stack.extend((p, "."))
1529        self.stack.pop()
def tablealias_sql(self, e: sqlglot.expressions.TableAlias) -> None:
1531    def tablealias_sql(self, e: exp.TableAlias) -> None:
1532        columns = e.columns
1533
1534        if columns:
1535            self.stack.extend((")", columns, "("))
1536
1537        self.stack.extend((e.this, " AS "))
def var_sql(self, e: sqlglot.expressions.Var) -> None:
1539    def var_sql(self, e: exp.Var) -> None:
1540        self.stack.append(e.this)